Further database refactoring

More refactoring of the SQLAlchemy database layer to improve
compatibility with eventlet on newer Pythons.

Inspired by 0ce2c41404

Related-Bug: 2067815
Change-Id: Ib5e9aa288232cc1b766bbf2a8ce2113d5a8e2f7d
(cherry picked from commit 753c44b0c4)
(cherry picked from commit 54b3b58428)
This commit is contained in:
James Page
2025-02-14 11:42:47 +00:00
committed by Alfredo Moralejo
parent accc7a2a22
commit 64ba589f80
8 changed files with 117 additions and 139 deletions

View File

@@ -11,12 +11,15 @@
# under the License.
from oslo_context import context
from oslo_db.sqlalchemy import enginefacade
from oslo_log import log
from oslo_utils import timeutils
LOG = log.getLogger(__name__)
@enginefacade.transaction_context_provider
class RequestContext(context.RequestContext):
"""Extends security contexts from the OpenStack common library."""

View File

@@ -13,8 +13,8 @@
from logging import config as log_config
from alembic import context
from oslo_db.sqlalchemy import enginefacade
from watcher.db.sqlalchemy import api as sqla_api
from watcher.db.sqlalchemy import models
# this is the Alembic Config object, which provides
@@ -43,7 +43,7 @@ def run_migrations_online():
and associate a connection with the context.
"""
engine = sqla_api.get_engine()
engine = enginefacade.writer.get_engine()
with engine.connect() as connection:
context.configure(connection=connection,
target_metadata=target_metadata)

View File

@@ -19,8 +19,10 @@
import collections
import datetime
import operator
import threading
from oslo_config import cfg
from oslo_db import api as oslo_db_api
from oslo_db import exception as db_exc
from oslo_db.sqlalchemy import enginefacade
from oslo_db.sqlalchemy import utils as db_utils
@@ -38,26 +40,7 @@ from watcher import objects
CONF = cfg.CONF
_FACADE = None
def _create_facade_lazily():
global _FACADE
if _FACADE is None:
ctx = enginefacade.transaction_context()
_FACADE = ctx.writer
return _FACADE
def get_engine():
facade = _create_facade_lazily()
return facade.get_engine()
def get_session(**kwargs):
facade = _create_facade_lazily()
sessionmaker = facade.get_sessionmaker()
return sessionmaker(**kwargs)
_CONTEXT = threading.local()
def get_backend():
@@ -65,14 +48,15 @@ def get_backend():
return Connection()
def model_query(model, *args, **kwargs):
"""Query helper for simpler session usage.
def _session_for_read():
return enginefacade.reader.using(_CONTEXT)
:param session: if present, the session to use
"""
session = kwargs.get('session') or get_session()
query = session.query(model, *args)
return query
# NOTE(tylerchristie) Please add @oslo_db_api.retry_on_deadlock decorator to
# any new methods using _session_for_write (as deadlocks happen on write), so
# that oslo_db is able to retry in case of deadlocks.
def _session_for_write():
return enginefacade.writer.using(_CONTEXT)
def add_identity_filter(query, value):
@@ -95,8 +79,6 @@ def add_identity_filter(query, value):
def _paginate_query(model, limit=None, marker=None, sort_key=None,
sort_dir=None, query=None):
if not query:
query = model_query(model)
sort_keys = ['id']
if sort_key and sort_key not in sort_keys:
sort_keys.insert(0, sort_key)
@@ -249,21 +231,20 @@ class Connection(api.BaseConnection):
query = query.options(joinedload(relationship.key))
return query
@oslo_db_api.retry_on_deadlock
def _create(self, model, values):
session = get_session()
with session.begin():
with _session_for_write() as session:
obj = model()
cleaned_values = {k: v for k, v in values.items()
if k not in self._get_relationships(model)}
obj.update(cleaned_values)
obj.save(session=session)
session.commit()
return obj
session.add(obj)
session.flush()
return obj
def _get(self, context, model, fieldname, value, eager):
session = get_session()
with session.begin():
query = model_query(model, session=session)
with _session_for_read() as session:
query = session.query(model)
if eager:
query = self._set_eager_options(model, query)
@@ -276,13 +257,13 @@ class Connection(api.BaseConnection):
except exc.NoResultFound:
raise exception.ResourceNotFound(name=model.__name__, id=value)
return obj
return obj
@staticmethod
@oslo_db_api.retry_on_deadlock
def _update(model, id_, values):
session = get_session()
with session.begin():
query = model_query(model, session=session)
with _session_for_write() as session:
query = session.query(model)
query = add_identity_filter(query, id_)
try:
ref = query.with_for_update().one()
@@ -290,13 +271,14 @@ class Connection(api.BaseConnection):
raise exception.ResourceNotFound(name=model.__name__, id=id_)
ref.update(values)
return ref
return ref
@staticmethod
@oslo_db_api.retry_on_deadlock
def _soft_delete(model, id_):
session = get_session()
with session.begin():
query = model_query(model, session=session)
with _session_for_write() as session:
query = session.query(model)
query = add_identity_filter(query, id_)
try:
row = query.one()
@@ -308,10 +290,10 @@ class Connection(api.BaseConnection):
return row
@staticmethod
@oslo_db_api.retry_on_deadlock
def _destroy(model, id_):
session = get_session()
with session.begin():
query = model_query(model, session=session)
with _session_for_write() as session:
query = session.query(model)
query = add_identity_filter(query, id_)
try:
@@ -324,14 +306,15 @@ class Connection(api.BaseConnection):
def _get_model_list(self, model, add_filters_func, context, filters=None,
limit=None, marker=None, sort_key=None, sort_dir=None,
eager=False):
query = model_query(model)
if eager:
query = self._set_eager_options(model, query)
query = add_filters_func(query, filters)
if not context.show_deleted:
query = query.filter(model.deleted_at.is_(None))
return _paginate_query(model, limit, marker,
sort_key, sort_dir, query)
with _session_for_read() as session:
query = session.query(model)
if eager:
query = self._set_eager_options(model, query)
query = add_filters_func(query, filters)
if not context.show_deleted:
query = query.filter(model.deleted_at.is_(None))
return _paginate_query(model, limit, marker,
sort_key, sort_dir, query)
# NOTE(erakli): _add_..._filters methods should be refactored to have same
# content. join_fieldmap should be filled with JoinMap instead of dict
@@ -426,11 +409,12 @@ class Connection(api.BaseConnection):
plain_fields=plain_fields, join_fieldmap=join_fieldmap)
if 'audit_uuid' in filters:
stmt = model_query(models.ActionPlan).join(
models.Audit,
models.Audit.id == models.ActionPlan.audit_id)\
.filter_by(uuid=filters['audit_uuid']).subquery()
query = query.filter_by(action_plan_id=stmt.c.id)
with _session_for_read() as session:
stmt = session.query(models.ActionPlan).join(
models.Audit,
models.Audit.id == models.ActionPlan.audit_id)\
.filter_by(uuid=filters['audit_uuid']).subquery()
query = query.filter_by(action_plan_id=stmt.c.id)
return query
@@ -608,20 +592,21 @@ class Connection(api.BaseConnection):
if not values.get('uuid'):
values['uuid'] = utils.generate_uuid()
query = model_query(models.AuditTemplate)
query = query.filter_by(name=values.get('name'),
deleted_at=None)
with _session_for_write() as session:
query = session.query(models.AuditTemplate)
query = query.filter_by(name=values.get('name'),
deleted_at=None)
if len(query.all()) > 0:
raise exception.AuditTemplateAlreadyExists(
audit_template=values['name'])
if len(query.all()) > 0:
raise exception.AuditTemplateAlreadyExists(
audit_template=values['name'])
try:
audit_template = self._create(models.AuditTemplate, values)
except db_exc.DBDuplicateEntry:
raise exception.AuditTemplateAlreadyExists(
audit_template=values['name'])
return audit_template
try:
audit_template = self._create(models.AuditTemplate, values)
except db_exc.DBDuplicateEntry:
raise exception.AuditTemplateAlreadyExists(
audit_template=values['name'])
return audit_template
def _get_audit_template(self, context, fieldname, value, eager):
try:
@@ -683,25 +668,26 @@ class Connection(api.BaseConnection):
if not values.get('uuid'):
values['uuid'] = utils.generate_uuid()
query = model_query(models.Audit)
query = query.filter_by(name=values.get('name'),
deleted_at=None)
with _session_for_write() as session:
query = session.query(models.Audit)
query = query.filter_by(name=values.get('name'),
deleted_at=None)
if len(query.all()) > 0:
raise exception.AuditAlreadyExists(
audit=values['name'])
if len(query.all()) > 0:
raise exception.AuditAlreadyExists(
audit=values['name'])
if values.get('state') is None:
values['state'] = objects.audit.State.PENDING
if values.get('state') is None:
values['state'] = objects.audit.State.PENDING
if not values.get('auto_trigger'):
values['auto_trigger'] = False
if not values.get('auto_trigger'):
values['auto_trigger'] = False
try:
audit = self._create(models.Audit, values)
except db_exc.DBDuplicateEntry:
raise exception.AuditAlreadyExists(audit=values['uuid'])
return audit
try:
audit = self._create(models.Audit, values)
except db_exc.DBDuplicateEntry:
raise exception.AuditAlreadyExists(audit=values['uuid'])
return audit
def _get_audit(self, context, fieldname, value, eager):
try:
@@ -725,14 +711,13 @@ class Connection(api.BaseConnection):
def destroy_audit(self, audit_id):
def is_audit_referenced(session, audit_id):
"""Checks whether the audit is referenced by action_plan(s)."""
query = model_query(models.ActionPlan, session=session)
query = session.query(models.ActionPlan)
query = self._add_action_plans_filters(
query, {'audit_id': audit_id})
return query.count() != 0
session = get_session()
with session.begin():
query = model_query(models.Audit, session=session)
with _session_for_write() as session:
query = session.query(models.Audit)
query = add_identity_filter(query, audit_id)
try:
@@ -799,9 +784,8 @@ class Connection(api.BaseConnection):
context, fieldname="uuid", value=action_uuid, eager=eager)
def destroy_action(self, action_id):
session = get_session()
with session.begin():
query = model_query(models.Action, session=session)
with _session_for_write() as session:
query = session.query(models.Action)
query = add_identity_filter(query, action_id)
count = query.delete()
if count != 1:
@@ -817,9 +801,8 @@ class Connection(api.BaseConnection):
@staticmethod
def _do_update_action(action_id, values):
session = get_session()
with session.begin():
query = model_query(models.Action, session=session)
with _session_for_write() as session:
query = session.query(models.Action)
query = add_identity_filter(query, action_id)
try:
ref = query.with_for_update().one()
@@ -827,7 +810,7 @@ class Connection(api.BaseConnection):
raise exception.ActionNotFound(action=action_id)
ref.update(values)
return ref
return ref
def soft_delete_action(self, action_id):
try:
@@ -871,14 +854,13 @@ class Connection(api.BaseConnection):
def destroy_action_plan(self, action_plan_id):
def is_action_plan_referenced(session, action_plan_id):
"""Checks whether the action_plan is referenced by action(s)."""
query = model_query(models.Action, session=session)
query = session.query(models.Action)
query = self._add_actions_filters(
query, {'action_plan_id': action_plan_id})
return query.count() != 0
session = get_session()
with session.begin():
query = model_query(models.ActionPlan, session=session)
with _session_for_write() as session:
query = session.query(models.ActionPlan)
query = add_identity_filter(query, action_plan_id)
try:
@@ -902,9 +884,8 @@ class Connection(api.BaseConnection):
@staticmethod
def _do_update_action_plan(action_plan_id, values):
session = get_session()
with session.begin():
query = model_query(models.ActionPlan, session=session)
with _session_for_write() as session:
query = session.query(models.ActionPlan)
query = add_identity_filter(query, action_plan_id)
try:
ref = query.with_for_update().one()
@@ -912,7 +893,7 @@ class Connection(api.BaseConnection):
raise exception.ActionPlanNotFound(action_plan=action_plan_id)
ref.update(values)
return ref
return ref
def soft_delete_action_plan(self, action_plan_id):
try:

View File

@@ -20,9 +20,9 @@ import alembic
from alembic import config as alembic_config
import alembic.migration as alembic_migration
from oslo_db import exception as db_exc
from oslo_db.sqlalchemy import enginefacade
from watcher._i18n import _
from watcher.db.sqlalchemy import api as sqla_api
from watcher.db.sqlalchemy import models
@@ -39,7 +39,7 @@ def version(engine=None):
:rtype: string
"""
if engine is None:
engine = sqla_api.get_engine()
engine = enginefacade.reader.get_engine()
with engine.connect() as conn:
context = alembic_migration.MigrationContext.configure(conn)
return context.get_current_revision()
@@ -63,7 +63,7 @@ def create_schema(config=None, engine=None):
Can be used for initial installation instead of upgrade('head').
"""
if engine is None:
engine = sqla_api.get_engine()
engine = enginefacade.writer.get_engine()
# NOTE(viktors): If we will use metadata.create_all() for non empty db
# schema, it will only add the new tables, but leave

View File

@@ -93,14 +93,6 @@ class WatcherBase(models.SoftDeleteMixin,
d[c.name] = self[c.name]
return d
def save(self, session=None):
import watcher.db.sqlalchemy.api as db_api
if session is None:
session = db_api.get_session()
super(WatcherBase, self).save(session)
Base = declarative_base(cls=WatcherBase)

View File

@@ -52,7 +52,7 @@ class ContinuousAuditHandler(base.AuditHandler):
self._audit_scheduler = scheduling.BackgroundSchedulerService(
jobstores={
'default': job_store.WatcherJobStore(
engine=sq_api.get_engine()),
engine=sq_api.enginefacade.writer.get_engine()),
}
)
return self._audit_scheduler

View File

@@ -17,9 +17,10 @@
import fixtures
from oslo_config import cfg
from oslo_db.sqlalchemy import enginefacade
from watcher.db import api as dbapi
from watcher.db.sqlalchemy import api as sqla_api
from watcher.db.sqlalchemy import migration
from watcher.db.sqlalchemy import models
from watcher.tests import base
@@ -35,16 +36,16 @@ _DB_CACHE = None
class Database(fixtures.Fixture):
def __init__(self, db_api, db_migrate, sql_connection):
def __init__(self, engine, db_migrate, sql_connection):
self.sql_connection = sql_connection
self.engine = db_api.get_engine()
self.engine = engine
self.engine.dispose()
conn = self.engine.connect()
self.setup_sqlite(db_migrate)
self.post_migrations()
self._DB = "".join(line for line in conn.connection.iterdump())
with self.engine.connect() as conn:
self.setup_sqlite(db_migrate)
self.post_migrations()
self._DB = "".join(line for line in conn.connection.iterdump())
self.engine.dispose()
def setup_sqlite(self, db_migrate):
@@ -55,9 +56,8 @@ class Database(fixtures.Fixture):
def setUp(self):
super(Database, self).setUp()
conn = self.engine.connect()
conn.connection.executescript(self._DB)
with self.engine.connect() as conn:
conn.connection.executescript(self._DB)
self.addCleanup(self.engine.dispose)
def post_migrations(self):
@@ -80,7 +80,9 @@ class DbTestCase(base.TestCase):
global _DB_CACHE
if not _DB_CACHE:
_DB_CACHE = Database(sqla_api, migration,
engine = enginefacade.writer.get_engine()
_DB_CACHE = Database(engine, migration,
sql_connection=CONF.database.connection)
engine.dispose()
self.useFixture(_DB_CACHE)
self._id_gen = utils.id_generator()

View File

@@ -263,7 +263,7 @@ class TestContinuousAuditHandler(base.DbTestCase):
cfg.CONF.set_override("host", "hostname1")
@mock.patch.object(objects.service.Service, 'list')
@mock.patch.object(sq_api, 'get_engine')
@mock.patch.object(sq_api.enginefacade.writer, 'get_engine')
@mock.patch.object(scheduling.BackgroundSchedulerService, 'add_job')
@mock.patch.object(scheduling.BackgroundSchedulerService, 'get_jobs')
@mock.patch.object(objects.audit.Audit, 'list')
@@ -286,7 +286,7 @@ class TestContinuousAuditHandler(base.DbTestCase):
self.assertIsNone(self.audits[1].next_run_time)
@mock.patch.object(objects.service.Service, 'list')
@mock.patch.object(sq_api, 'get_engine')
@mock.patch.object(sq_api.enginefacade.writer, 'get_engine')
@mock.patch.object(scheduling.BackgroundSchedulerService, 'add_job')
@mock.patch.object(scheduling.BackgroundSchedulerService, 'get_jobs')
@mock.patch.object(objects.audit.Audit, 'list')
@@ -309,7 +309,7 @@ class TestContinuousAuditHandler(base.DbTestCase):
@mock.patch.object(continuous.ContinuousAuditHandler, '_next_cron_time')
@mock.patch.object(objects.service.Service, 'list')
@mock.patch.object(sq_api, 'get_engine')
@mock.patch.object(sq_api.enginefacade.writer, 'get_engine')
@mock.patch.object(scheduling.BackgroundSchedulerService, 'add_job')
@mock.patch.object(scheduling.BackgroundSchedulerService, 'get_jobs')
@mock.patch.object(objects.audit.Audit, 'list')
@@ -328,7 +328,7 @@ class TestContinuousAuditHandler(base.DbTestCase):
audit_handler.launch_audits_periodically)
@mock.patch.object(objects.service.Service, 'list')
@mock.patch.object(sq_api, 'get_engine')
@mock.patch.object(sq_api.enginefacade.writer, 'get_engine')
@mock.patch.object(scheduling.BackgroundSchedulerService, 'add_job')
@mock.patch.object(scheduling.BackgroundSchedulerService, 'get_jobs')
@mock.patch.object(objects.audit.Audit, 'list')
@@ -349,7 +349,7 @@ class TestContinuousAuditHandler(base.DbTestCase):
m_add_job.assert_has_calls(calls)
@mock.patch.object(objects.service.Service, 'list')
@mock.patch.object(sq_api, 'get_engine')
@mock.patch.object(sq_api.enginefacade.writer, 'get_engine')
@mock.patch.object(scheduling.BackgroundSchedulerService, 'add_job')
@mock.patch.object(scheduling.BackgroundSchedulerService, 'get_jobs')
@mock.patch.object(objects.audit.Audit, 'list')
@@ -384,7 +384,7 @@ class TestContinuousAuditHandler(base.DbTestCase):
self.assertTrue(is_inactive)
@mock.patch.object(objects.service.Service, 'list')
@mock.patch.object(sq_api, 'get_engine')
@mock.patch.object(sq_api.enginefacade.writer, 'get_engine')
@mock.patch.object(scheduling.BackgroundSchedulerService, 'get_jobs')
@mock.patch.object(objects.audit.AuditStateTransitionManager,
'is_inactive')
@@ -406,7 +406,7 @@ class TestContinuousAuditHandler(base.DbTestCase):
self.assertIsNotNone(self.audits[0].next_run_time)
@mock.patch.object(objects.service.Service, 'list')
@mock.patch.object(sq_api, 'get_engine')
@mock.patch.object(sq_api.enginefacade.writer, 'get_engine')
@mock.patch.object(scheduling.BackgroundSchedulerService, 'remove_job')
@mock.patch.object(scheduling.BackgroundSchedulerService, 'add_job')
@mock.patch.object(scheduling.BackgroundSchedulerService, 'get_jobs')