From 753c44b0c465f018d7c61f5ffce24b603e2f12c0 Mon Sep 17 00:00:00 2001 From: James Page Date: Fri, 14 Feb 2025 11:42:47 +0000 Subject: [PATCH] Further database refactoring More refactoring of the SQLAlchemy database layer to improve compatibility with eventlet on newer Pythons. Inspired by https://opendev.org/openstack/magnum/commit/0ce2c41404f1f8dcd1bcd19d36a885edc34926a2 Related-Bug: 2067815 Change-Id: Ib5e9aa288232cc1b766bbf2a8ce2113d5a8e2f7d --- watcher/common/context.py | 3 + watcher/db/sqlalchemy/alembic/env.py | 4 +- watcher/db/sqlalchemy/api.py | 195 ++++++++---------- watcher/db/sqlalchemy/migration.py | 6 +- watcher/db/sqlalchemy/models.py | 8 - watcher/decision_engine/audit/continuous.py | 2 +- watcher/tests/db/base.py | 24 ++- .../audit/test_audit_handlers.py | 14 +- 8 files changed, 117 insertions(+), 139 deletions(-) diff --git a/watcher/common/context.py b/watcher/common/context.py index 7a82cad8c..32dba452b 100644 --- a/watcher/common/context.py +++ b/watcher/common/context.py @@ -11,12 +11,15 @@ # under the License. from oslo_context import context +from oslo_db.sqlalchemy import enginefacade from oslo_log import log from oslo_utils import timeutils + LOG = log.getLogger(__name__) +@enginefacade.transaction_context_provider class RequestContext(context.RequestContext): """Extends security contexts from the OpenStack common library.""" diff --git a/watcher/db/sqlalchemy/alembic/env.py b/watcher/db/sqlalchemy/alembic/env.py index 474b1ca66..9e4c9cdc9 100644 --- a/watcher/db/sqlalchemy/alembic/env.py +++ b/watcher/db/sqlalchemy/alembic/env.py @@ -13,8 +13,8 @@ from logging import config as log_config from alembic import context +from oslo_db.sqlalchemy import enginefacade -from watcher.db.sqlalchemy import api as sqla_api from watcher.db.sqlalchemy import models # this is the Alembic Config object, which provides @@ -43,7 +43,7 @@ def run_migrations_online(): and associate a connection with the context. 
""" - engine = sqla_api.get_engine() + engine = enginefacade.writer.get_engine() with engine.connect() as connection: context.configure(connection=connection, target_metadata=target_metadata) diff --git a/watcher/db/sqlalchemy/api.py b/watcher/db/sqlalchemy/api.py index dad079172..33ed128bf 100644 --- a/watcher/db/sqlalchemy/api.py +++ b/watcher/db/sqlalchemy/api.py @@ -19,8 +19,10 @@ import collections import datetime import operator +import threading from oslo_config import cfg +from oslo_db import api as oslo_db_api from oslo_db import exception as db_exc from oslo_db.sqlalchemy import enginefacade from oslo_db.sqlalchemy import utils as db_utils @@ -38,26 +40,7 @@ from watcher import objects CONF = cfg.CONF -_FACADE = None - - -def _create_facade_lazily(): - global _FACADE - if _FACADE is None: - ctx = enginefacade.transaction_context() - _FACADE = ctx.writer - return _FACADE - - -def get_engine(): - facade = _create_facade_lazily() - return facade.get_engine() - - -def get_session(**kwargs): - facade = _create_facade_lazily() - sessionmaker = facade.get_sessionmaker() - return sessionmaker(**kwargs) +_CONTEXT = threading.local() def get_backend(): @@ -65,14 +48,15 @@ def get_backend(): return Connection() -def model_query(model, *args, **kwargs): - """Query helper for simpler session usage. +def _session_for_read(): + return enginefacade.reader.using(_CONTEXT) - :param session: if present, the session to use - """ - session = kwargs.get('session') or get_session() - query = session.query(model, *args) - return query + +# NOTE(tylerchristie) Please add @oslo_db_api.retry_on_deadlock decorator to +# any new methods using _session_for_write (as deadlocks happen on write), so +# that oslo_db is able to retry in case of deadlocks. 
+def _session_for_write(): + return enginefacade.writer.using(_CONTEXT) def add_identity_filter(query, value): @@ -95,8 +79,6 @@ def add_identity_filter(query, value): def _paginate_query(model, limit=None, marker=None, sort_key=None, sort_dir=None, query=None): - if not query: - query = model_query(model) sort_keys = ['id'] if sort_key and sort_key not in sort_keys: sort_keys.insert(0, sort_key) @@ -250,21 +232,20 @@ class Connection(api.BaseConnection): getattr(model, relationship.key))) return query + @oslo_db_api.retry_on_deadlock def _create(self, model, values): - session = get_session() - with session.begin(): + with _session_for_write() as session: obj = model() cleaned_values = {k: v for k, v in values.items() if k not in self._get_relationships(model)} obj.update(cleaned_values) - obj.save(session=session) - session.commit() - return obj + session.add(obj) + session.flush() + return obj def _get(self, context, model, fieldname, value, eager): - session = get_session() - with session.begin(): - query = model_query(model, session=session) + with _session_for_read() as session: + query = session.query(model) if eager: query = self._set_eager_options(model, query) @@ -277,13 +258,13 @@ class Connection(api.BaseConnection): except exc.NoResultFound: raise exception.ResourceNotFound(name=model.__name__, id=value) - return obj + return obj @staticmethod + @oslo_db_api.retry_on_deadlock def _update(model, id_, values): - session = get_session() - with session.begin(): - query = model_query(model, session=session) + with _session_for_write() as session: + query = session.query(model) query = add_identity_filter(query, id_) try: ref = query.with_for_update().one() @@ -291,13 +272,14 @@ class Connection(api.BaseConnection): raise exception.ResourceNotFound(name=model.__name__, id=id_) ref.update(values) - return ref + + return ref @staticmethod + @oslo_db_api.retry_on_deadlock def _soft_delete(model, id_): - session = get_session() - with session.begin(): - query = 
model_query(model, session=session) + with _session_for_write() as session: + query = session.query(model) query = add_identity_filter(query, id_) try: row = query.one() @@ -309,10 +291,10 @@ class Connection(api.BaseConnection): return row @staticmethod + @oslo_db_api.retry_on_deadlock def _destroy(model, id_): - session = get_session() - with session.begin(): - query = model_query(model, session=session) + with _session_for_write() as session: + query = session.query(model) query = add_identity_filter(query, id_) try: @@ -325,14 +307,15 @@ class Connection(api.BaseConnection): def _get_model_list(self, model, add_filters_func, context, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None, eager=False): - query = model_query(model) - if eager: - query = self._set_eager_options(model, query) - query = add_filters_func(query, filters) - if not context.show_deleted: - query = query.filter(model.deleted_at.is_(None)) - return _paginate_query(model, limit, marker, - sort_key, sort_dir, query) + with _session_for_read() as session: + query = session.query(model) + if eager: + query = self._set_eager_options(model, query) + query = add_filters_func(query, filters) + if not context.show_deleted: + query = query.filter(model.deleted_at.is_(None)) + return _paginate_query(model, limit, marker, + sort_key, sort_dir, query) # NOTE(erakli): _add_..._filters methods should be refactored to have same # content. 
join_fieldmap should be filled with JoinMap instead of dict @@ -427,11 +410,12 @@ class Connection(api.BaseConnection): plain_fields=plain_fields, join_fieldmap=join_fieldmap) if 'audit_uuid' in filters: - stmt = model_query(models.ActionPlan).join( - models.Audit, - models.Audit.id == models.ActionPlan.audit_id)\ - .filter_by(uuid=filters['audit_uuid']).subquery() - query = query.filter_by(action_plan_id=stmt.c.id) + with _session_for_read() as session: + stmt = session.query(models.ActionPlan).join( + models.Audit, + models.Audit.id == models.ActionPlan.audit_id)\ + .filter_by(uuid=filters['audit_uuid']).subquery() + query = query.filter_by(action_plan_id=stmt.c.id) return query @@ -609,20 +593,21 @@ class Connection(api.BaseConnection): if not values.get('uuid'): values['uuid'] = utils.generate_uuid() - query = model_query(models.AuditTemplate) - query = query.filter_by(name=values.get('name'), - deleted_at=None) + with _session_for_write() as session: + query = session.query(models.AuditTemplate) + query = query.filter_by(name=values.get('name'), + deleted_at=None) - if len(query.all()) > 0: - raise exception.AuditTemplateAlreadyExists( - audit_template=values['name']) + if len(query.all()) > 0: + raise exception.AuditTemplateAlreadyExists( + audit_template=values['name']) - try: - audit_template = self._create(models.AuditTemplate, values) - except db_exc.DBDuplicateEntry: - raise exception.AuditTemplateAlreadyExists( - audit_template=values['name']) - return audit_template + try: + audit_template = self._create(models.AuditTemplate, values) + except db_exc.DBDuplicateEntry: + raise exception.AuditTemplateAlreadyExists( + audit_template=values['name']) + return audit_template def _get_audit_template(self, context, fieldname, value, eager): try: @@ -684,25 +669,26 @@ class Connection(api.BaseConnection): if not values.get('uuid'): values['uuid'] = utils.generate_uuid() - query = model_query(models.Audit) - query = query.filter_by(name=values.get('name'), - 
deleted_at=None) + with _session_for_write() as session: + query = session.query(models.Audit) + query = query.filter_by(name=values.get('name'), + deleted_at=None) - if len(query.all()) > 0: - raise exception.AuditAlreadyExists( - audit=values['name']) + if len(query.all()) > 0: + raise exception.AuditAlreadyExists( + audit=values['name']) - if values.get('state') is None: - values['state'] = objects.audit.State.PENDING + if values.get('state') is None: + values['state'] = objects.audit.State.PENDING - if not values.get('auto_trigger'): - values['auto_trigger'] = False + if not values.get('auto_trigger'): + values['auto_trigger'] = False - try: - audit = self._create(models.Audit, values) - except db_exc.DBDuplicateEntry: - raise exception.AuditAlreadyExists(audit=values['uuid']) - return audit + try: + audit = self._create(models.Audit, values) + except db_exc.DBDuplicateEntry: + raise exception.AuditAlreadyExists(audit=values['uuid']) + return audit def _get_audit(self, context, fieldname, value, eager): try: @@ -726,14 +712,13 @@ class Connection(api.BaseConnection): def destroy_audit(self, audit_id): def is_audit_referenced(session, audit_id): """Checks whether the audit is referenced by action_plan(s).""" - query = model_query(models.ActionPlan, session=session) + query = session.query(models.ActionPlan) query = self._add_action_plans_filters( query, {'audit_id': audit_id}) return query.count() != 0 - session = get_session() - with session.begin(): - query = model_query(models.Audit, session=session) + with _session_for_write() as session: + query = session.query(models.Audit) query = add_identity_filter(query, audit_id) try: @@ -800,9 +785,8 @@ class Connection(api.BaseConnection): context, fieldname="uuid", value=action_uuid, eager=eager) def destroy_action(self, action_id): - session = get_session() - with session.begin(): - query = model_query(models.Action, session=session) + with _session_for_write() as session: + query = session.query(models.Action) 
query = add_identity_filter(query, action_id) count = query.delete() if count != 1: @@ -818,9 +802,8 @@ class Connection(api.BaseConnection): @staticmethod def _do_update_action(action_id, values): - session = get_session() - with session.begin(): - query = model_query(models.Action, session=session) + with _session_for_write() as session: + query = session.query(models.Action) query = add_identity_filter(query, action_id) try: ref = query.with_for_update().one() @@ -828,7 +811,7 @@ class Connection(api.BaseConnection): raise exception.ActionNotFound(action=action_id) ref.update(values) - return ref + return ref def soft_delete_action(self, action_id): try: @@ -872,14 +855,13 @@ class Connection(api.BaseConnection): def destroy_action_plan(self, action_plan_id): def is_action_plan_referenced(session, action_plan_id): """Checks whether the action_plan is referenced by action(s).""" - query = model_query(models.Action, session=session) + query = session.query(models.Action) query = self._add_actions_filters( query, {'action_plan_id': action_plan_id}) return query.count() != 0 - session = get_session() - with session.begin(): - query = model_query(models.ActionPlan, session=session) + with _session_for_write() as session: + query = session.query(models.ActionPlan) query = add_identity_filter(query, action_plan_id) try: @@ -903,9 +885,8 @@ class Connection(api.BaseConnection): @staticmethod def _do_update_action_plan(action_plan_id, values): - session = get_session() - with session.begin(): - query = model_query(models.ActionPlan, session=session) + with _session_for_write() as session: + query = session.query(models.ActionPlan) query = add_identity_filter(query, action_plan_id) try: ref = query.with_for_update().one() @@ -913,7 +894,7 @@ class Connection(api.BaseConnection): raise exception.ActionPlanNotFound(action_plan=action_plan_id) ref.update(values) - return ref + return ref def soft_delete_action_plan(self, action_plan_id): try: diff --git 
a/watcher/db/sqlalchemy/migration.py b/watcher/db/sqlalchemy/migration.py index 415214664..adf74b8d5 100644 --- a/watcher/db/sqlalchemy/migration.py +++ b/watcher/db/sqlalchemy/migration.py @@ -20,9 +20,9 @@ import alembic from alembic import config as alembic_config import alembic.migration as alembic_migration from oslo_db import exception as db_exc +from oslo_db.sqlalchemy import enginefacade from watcher._i18n import _ -from watcher.db.sqlalchemy import api as sqla_api from watcher.db.sqlalchemy import models @@ -39,7 +39,7 @@ def version(engine=None): :rtype: string """ if engine is None: - engine = sqla_api.get_engine() + engine = enginefacade.reader.get_engine() with engine.connect() as conn: context = alembic_migration.MigrationContext.configure(conn) return context.get_current_revision() @@ -63,7 +63,7 @@ def create_schema(config=None, engine=None): Can be used for initial installation instead of upgrade('head'). """ if engine is None: - engine = sqla_api.get_engine() + engine = enginefacade.writer.get_engine() # NOTE(viktors): If we will use metadata.create_all() for non empty db # schema, it will only add the new tables, but leave diff --git a/watcher/db/sqlalchemy/models.py b/watcher/db/sqlalchemy/models.py index 55e60bca8..6ae06378f 100644 --- a/watcher/db/sqlalchemy/models.py +++ b/watcher/db/sqlalchemy/models.py @@ -93,14 +93,6 @@ class WatcherBase(models.SoftDeleteMixin, d[c.name] = self[c.name] return d - def save(self, session=None): - import watcher.db.sqlalchemy.api as db_api - - if session is None: - session = db_api.get_session() - - super(WatcherBase, self).save(session) - Base = declarative_base(cls=WatcherBase) diff --git a/watcher/decision_engine/audit/continuous.py b/watcher/decision_engine/audit/continuous.py index cc4b14b8f..e9ae1bec1 100644 --- a/watcher/decision_engine/audit/continuous.py +++ b/watcher/decision_engine/audit/continuous.py @@ -53,7 +53,7 @@ class ContinuousAuditHandler(base.AuditHandler): self._audit_scheduler = 
scheduling.BackgroundSchedulerService( jobstores={ 'default': job_store.WatcherJobStore( - engine=sq_api.get_engine()), + engine=sq_api.enginefacade.writer.get_engine()), } ) return self._audit_scheduler diff --git a/watcher/tests/db/base.py b/watcher/tests/db/base.py index 57d7a610e..4cf25eac0 100644 --- a/watcher/tests/db/base.py +++ b/watcher/tests/db/base.py @@ -17,9 +17,10 @@ import fixtures from oslo_config import cfg +from oslo_db.sqlalchemy import enginefacade + from watcher.db import api as dbapi -from watcher.db.sqlalchemy import api as sqla_api from watcher.db.sqlalchemy import migration from watcher.db.sqlalchemy import models from watcher.tests import base @@ -35,16 +36,16 @@ _DB_CACHE = None class Database(fixtures.Fixture): - def __init__(self, db_api, db_migrate, sql_connection): + def __init__(self, engine, db_migrate, sql_connection): self.sql_connection = sql_connection - self.engine = db_api.get_engine() + self.engine = engine self.engine.dispose() - conn = self.engine.connect() - self.setup_sqlite(db_migrate) - self.post_migrations() - self._DB = "".join(line for line in conn.connection.iterdump()) + with self.engine.connect() as conn: + self.setup_sqlite(db_migrate) + self.post_migrations() + self._DB = "".join(line for line in conn.connection.iterdump()) self.engine.dispose() def setup_sqlite(self, db_migrate): @@ -55,9 +56,8 @@ class Database(fixtures.Fixture): def setUp(self): super(Database, self).setUp() - - conn = self.engine.connect() - conn.connection.executescript(self._DB) + with self.engine.connect() as conn: + conn.connection.executescript(self._DB) self.addCleanup(self.engine.dispose) def post_migrations(self): @@ -80,7 +80,9 @@ class DbTestCase(base.TestCase): global _DB_CACHE if not _DB_CACHE: - _DB_CACHE = Database(sqla_api, migration, + engine = enginefacade.writer.get_engine() + _DB_CACHE = Database(engine, migration, sql_connection=CONF.database.connection) + engine.dispose() self.useFixture(_DB_CACHE) self._id_gen = 
utils.id_generator() diff --git a/watcher/tests/decision_engine/audit/test_audit_handlers.py b/watcher/tests/decision_engine/audit/test_audit_handlers.py index 58becc3a8..79f06976e 100644 --- a/watcher/tests/decision_engine/audit/test_audit_handlers.py +++ b/watcher/tests/decision_engine/audit/test_audit_handlers.py @@ -264,7 +264,7 @@ class TestContinuousAuditHandler(base.DbTestCase): cfg.CONF.set_override("host", "hostname1") @mock.patch.object(objects.service.Service, 'list') - @mock.patch.object(sq_api, 'get_engine') + @mock.patch.object(sq_api.enginefacade.writer, 'get_engine') @mock.patch.object(scheduling.BackgroundSchedulerService, 'add_job') @mock.patch.object(scheduling.BackgroundSchedulerService, 'get_jobs') @mock.patch.object(objects.audit.Audit, 'list') @@ -287,7 +287,7 @@ class TestContinuousAuditHandler(base.DbTestCase): self.assertIsNone(self.audits[1].next_run_time) @mock.patch.object(objects.service.Service, 'list') - @mock.patch.object(sq_api, 'get_engine') + @mock.patch.object(sq_api.enginefacade.writer, 'get_engine') @mock.patch.object(scheduling.BackgroundSchedulerService, 'add_job') @mock.patch.object(scheduling.BackgroundSchedulerService, 'get_jobs') @mock.patch.object(objects.audit.Audit, 'list') @@ -310,7 +310,7 @@ class TestContinuousAuditHandler(base.DbTestCase): @mock.patch.object(continuous.ContinuousAuditHandler, '_next_cron_time') @mock.patch.object(objects.service.Service, 'list') - @mock.patch.object(sq_api, 'get_engine') + @mock.patch.object(sq_api.enginefacade.writer, 'get_engine') @mock.patch.object(scheduling.BackgroundSchedulerService, 'add_job') @mock.patch.object(scheduling.BackgroundSchedulerService, 'get_jobs') @mock.patch.object(objects.audit.Audit, 'list') @@ -329,7 +329,7 @@ class TestContinuousAuditHandler(base.DbTestCase): audit_handler.launch_audits_periodically) @mock.patch.object(objects.service.Service, 'list') - @mock.patch.object(sq_api, 'get_engine') + @mock.patch.object(sq_api.enginefacade.writer, 
'get_engine') @mock.patch.object(scheduling.BackgroundSchedulerService, 'add_job') @mock.patch.object(scheduling.BackgroundSchedulerService, 'get_jobs') @mock.patch.object(objects.audit.Audit, 'list') @@ -350,7 +350,7 @@ class TestContinuousAuditHandler(base.DbTestCase): m_add_job.assert_has_calls(calls) @mock.patch.object(objects.service.Service, 'list') - @mock.patch.object(sq_api, 'get_engine') + @mock.patch.object(sq_api.enginefacade.writer, 'get_engine') @mock.patch.object(scheduling.BackgroundSchedulerService, 'add_job') @mock.patch.object(scheduling.BackgroundSchedulerService, 'get_jobs') @mock.patch.object(objects.audit.Audit, 'list') @@ -385,7 +385,7 @@ class TestContinuousAuditHandler(base.DbTestCase): self.assertTrue(is_inactive) @mock.patch.object(objects.service.Service, 'list') - @mock.patch.object(sq_api, 'get_engine') + @mock.patch.object(sq_api.enginefacade.writer, 'get_engine') @mock.patch.object(scheduling.BackgroundSchedulerService, 'get_jobs') @mock.patch.object(objects.audit.AuditStateTransitionManager, 'is_inactive') @@ -407,7 +407,7 @@ class TestContinuousAuditHandler(base.DbTestCase): self.assertIsNotNone(self.audits[0].next_run_time) @mock.patch.object(objects.service.Service, 'list') - @mock.patch.object(sq_api, 'get_engine') + @mock.patch.object(sq_api.enginefacade.writer, 'get_engine') @mock.patch.object(scheduling.BackgroundSchedulerService, 'remove_job') @mock.patch.object(scheduling.BackgroundSchedulerService, 'add_job') @mock.patch.object(scheduling.BackgroundSchedulerService, 'get_jobs')