diff --git a/watcher/common/exception.py b/watcher/common/exception.py index 8c517150f..370b7d071 100644 --- a/watcher/common/exception.py +++ b/watcher/common/exception.py @@ -150,6 +150,11 @@ class InvalidIdentity(Invalid): msg_fmt = _("Expected a uuid or int but received %(identity)s") +class InvalidOperator(Invalid): + msg_fmt = _("Filter operator is not valid: %(operator)s not " + "in %(valid_operators)s") + + class InvalidGoal(Invalid): msg_fmt = _("Goal %(goal)s is invalid") diff --git a/watcher/db/api.py b/watcher/db/api.py index 1416b2fe1..1e276493f 100644 --- a/watcher/db/api.py +++ b/watcher/db/api.py @@ -139,8 +139,6 @@ class BaseConnection(object): match the specified filters. :param context: The security context - :param columns: List of column names to return. - Defaults to 'id' column when columns == None. :param filters: Filters to apply. Defaults to None. :param limit: Maximum number of strategies to return. @@ -221,7 +219,7 @@ class BaseConnection(object): """ @abc.abstractmethod - def get_audit_template_list(self, context, columns=None, filters=None, + def get_audit_template_list(self, context, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): """Get specific columns for matching audit templates. @@ -230,8 +228,6 @@ class BaseConnection(object): match the specified filters. :param context: The security context - :param columns: List of column names to return. - Defaults to 'id' column when columns == None. :param filters: Filters to apply. Defaults to None. :param limit: Maximum number of audit templates to return. @@ -320,7 +316,7 @@ class BaseConnection(object): """ @abc.abstractmethod - def get_audit_list(self, context, columns=None, filters=None, limit=None, + def get_audit_list(self, context, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): """Get specific columns for matching audits. @@ -328,8 +324,6 @@ class BaseConnection(object): specified filters. :param context: The security context - :param columns: List of column names to return. - Defaults to 'id' column when columns == None. :param filters: Filters to apply. Defaults to None. :param limit: Maximum number of audits to return. @@ -407,7 +401,7 @@ class BaseConnection(object): """ @abc.abstractmethod - def get_action_list(self, context, columns=None, filters=None, limit=None, + def get_action_list(self, context, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): """Get specific columns for matching actions. @@ -415,8 +409,6 @@ class BaseConnection(object): specified filters. :param context: The security context - :param columns: List of column names to return. - Defaults to 'id' column when columns == None. :param filters: Filters to apply. Defaults to None. :param limit: Maximum number of actions to return. @@ -490,7 +482,7 @@ class BaseConnection(object): @abc.abstractmethod def get_action_plan_list( - self, context, columns=None, filters=None, limit=None, + self, context, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): """Get specific columns for matching action plans. @@ -498,8 +490,6 @@ class BaseConnection(object): match the specified filters. :param context: The security context - :param columns: List of column names to return. - Defaults to 'id' column when columns == None. :param filters: Filters to apply. Defaults to None. :param limit: Maximum number of audits to return. diff --git a/watcher/db/sqlalchemy/api.py b/watcher/db/sqlalchemy/api.py index 0bb1b64db..874e2c931 100644 --- a/watcher/db/sqlalchemy/api.py +++ b/watcher/db/sqlalchemy/api.py @@ -14,10 +14,11 @@ # License for the specific language governing permissions and limitations # under the License. - """SQLAlchemy storage backend.""" import collections +import datetime +import operator from oslo_config import cfg from oslo_db import exception as db_exc @@ -25,7 +26,7 @@ from oslo_db.sqlalchemy import session as db_session from oslo_db.sqlalchemy import utils as db_utils from sqlalchemy.orm import exc -from watcher import _i18n +from watcher._i18n import _ from watcher.common import exception from watcher.common import utils from watcher.db import api @@ -36,7 +37,6 @@ from watcher.objects import audit as audit_objects from watcher.objects import utils as objutils CONF = cfg.CONF -_ = _i18n._ _FACADE = None @@ -115,117 +115,121 @@ NaturalJoinFilter = collections.namedtuple( class Connection(api.BaseConnection): """SqlAlchemy connection.""" + valid_operators = { + "": operator.eq, + "eq": operator.eq, + "neq": operator.ne, + "gt": operator.gt, + "gte": operator.ge, + "lt": operator.lt, + "lte": operator.le, + "in": lambda field, choices: field.in_(choices), + "notin": lambda field, choices: field.notin_(choices), + } + def __init__(self): super(Connection, self).__init__() - def __add_soft_delete_mixin_filters(self, query, filters, model): - if 'deleted' in filters: - if bool(filters['deleted']): - query = query.filter(model.deleted != 0) - else: - query = query.filter(model.deleted == 0) - if 'deleted_at__eq' in filters: - query = query.filter( - model.deleted_at == objutils.datetime_or_str_or_none( - filters['deleted_at__eq'])) - if 'deleted_at__gt' in filters: - query = query.filter( - model.deleted_at > objutils.datetime_or_str_or_none( - filters['deleted_at__gt'])) - if 'deleted_at__gte' in filters: - query = query.filter( - model.deleted_at >= objutils.datetime_or_str_or_none( - filters['deleted_at__gte'])) - if 'deleted_at__lt' in filters: - query = query.filter( - model.deleted_at < objutils.datetime_or_str_or_none( - filters['deleted_at__lt'])) - if 'deleted_at__lte' in filters: - query = query.filter( - model.deleted_at <= objutils.datetime_or_str_or_none( - filters['deleted_at__lte'])) + def __add_simple_filter(self, query, model, fieldname, value, operator_): + field = getattr(model, fieldname) - return query + if field.type.python_type is datetime.datetime: + value = objutils.datetime_or_str_or_none(value) - def __add_timestamp_mixin_filters(self, query, filters, model): - if 'created_at__eq' in filters: - query = query.filter( - model.created_at == objutils.datetime_or_str_or_none( - filters['created_at__eq'])) - if 'created_at__gt' in filters: - query = query.filter( - model.created_at > objutils.datetime_or_str_or_none( - filters['created_at__gt'])) - if 'created_at__gte' in filters: - query = query.filter( - model.created_at >= objutils.datetime_or_str_or_none( - filters['created_at__gte'])) - if 'created_at__lt' in filters: - query = query.filter( - model.created_at < objutils.datetime_or_str_or_none( - filters['created_at__lt'])) - if 'created_at__lte' in filters: - query = query.filter( - model.created_at <= objutils.datetime_or_str_or_none( - filters['created_at__lte'])) + return query.filter(self.valid_operators[operator_](field, value)) - if 'updated_at__eq' in filters: - query = query.filter( - model.updated_at == objutils.datetime_or_str_or_none( - filters['updated_at__eq'])) - if 'updated_at__gt' in filters: - query = query.filter( - model.updated_at > objutils.datetime_or_str_or_none( - filters['updated_at__gt'])) - if 'updated_at__gte' in filters: - query = query.filter( - model.updated_at >= objutils.datetime_or_str_or_none( - filters['updated_at__gte'])) - if 'updated_at__lt' in filters: - query = query.filter( - model.updated_at < objutils.datetime_or_str_or_none( - filters['updated_at__lt'])) - if 'updated_at__lte' in filters: - query = query.filter( - model.updated_at <= objutils.datetime_or_str_or_none( - filters['updated_at__lte'])) + def __add_join_filter(self, query, model, fieldname, value, operator_): + query = query.join(model) + return self.__add_simple_filter(query, model, fieldname, + value, operator_) - return query + def __decompose_filter(self, raw_fieldname): + """Decompose a filter name into its 2 subparts - def __add_simple_filter(self, query, model, fieldname, value): - return query.filter(getattr(model, fieldname) == value) + A filter can take 2 forms: - def __add_natural_join_filter(self, query, join_model, - join_fieldname, value): - query = query.join(join_model) - return self.__add_simple_filter( - query, join_model, join_fieldname, value) + - "" which is a syntactic sugar for "__eq" + - "__" where is the comparison operator + to be used. + + Available operators are: + + - eq + - neq + - gt + - gte + - lt + - lte + - in + - notin + """ + separator = '__' + fieldname, separator, operator_ = raw_fieldname.partition(separator) + + if operator_ and operator_ not in self.valid_operators: + raise exception.InvalidOperator( + operator=operator_, valid_operators=self.valid_operators) + + return fieldname, operator_ def _add_filters(self, query, model, filters=None, plain_fields=None, join_fieldmap=None): """Generic way to add filters to a Watcher model + Each filter key provided by the `filters` parameter will be decomposed + into 2 pieces: the field name and the comparison operator + + - "": By default, the "eq" is applied if no operator is provided + - "eq", which stands for "equal" : e.g. {"state__eq": "PENDING"} + will result in the "WHERE state = 'PENDING'" clause. + - "neq", which stands for "not equal" : e.g. {"state__neq": "PENDING"} + will result in the "WHERE state != 'PENDING'" clause. + - "gt", which stands for "greater than" : e.g. + {"created_at__gt": "2016-06-06T10:33:22.063176"} will result in the + "WHERE created_at > '2016-06-06T10:33:22.063176'" clause. + - "gte", which stands for "greater than or equal to" : e.g. + {"created_at__gte": "2016-06-06T10:33:22.063176"} will result in the + "WHERE created_at >= '2016-06-06T10:33:22.063176'" clause. + - "lt", which stands for "less than" : e.g. + {"created_at__lt": "2016-06-06T10:33:22.063176"} will result in the + "WHERE created_at < '2016-06-06T10:33:22.063176'" clause. + - "lte", which stands for "less than or equal to" : e.g. + {"created_at__lte": "2016-06-06T10:33:22.063176"} will result in the + "WHERE created_at <= '2016-06-06T10:33:22.063176'" clause. + - "in": e.g. {"state__in": ('SUCCEEDED', 'FAILED')} will result in the + "WHERE state IN ('SUCCEEDED', 'FAILED')" clause. + :param query: a :py:class:`sqlalchemy.orm.query.Query` instance :param model: the model class the filters should relate to :param filters: dict with the following structure {"fieldname": value} :param plain_fields: a :py:class:`sqlalchemy.orm.query.Query` instance :param join_fieldmap: a :py:class:`sqlalchemy.orm.query.Query` instance """ + soft_delete_mixin_fields = ['deleted', 'deleted_at'] + timestamp_mixin_fields = ['created_at', 'updated_at'] filters = filters or {} - plain_fields = plain_fields or () - join_fieldmap = join_fieldmap or JoinMap() - for fieldname, value in filters.items(): + # Special case for 'deleted' because it is a non-boolean flag + if 'deleted' in filters: + deleted_filter = filters.pop('deleted') + op = 'eq' if not bool(deleted_filter) else 'neq' + filters['deleted__%s' % op] = 0 + + plain_fields = tuple( + (list(plain_fields) or []) + + soft_delete_mixin_fields + + timestamp_mixin_fields) + join_fieldmap = join_fieldmap or {} + + for raw_fieldname, value in filters.items(): + fieldname, operator_ = self.__decompose_filter(raw_fieldname) if fieldname in plain_fields: query = self.__add_simple_filter( - query, model, fieldname, value) + query, model, fieldname, value, operator_) elif fieldname in join_fieldmap: - join_fieldname, join_model = join_fieldmap[fieldname] - query = self.__add_natural_join_filter( - query, join_model, join_fieldname, value) - - query = self.__add_soft_delete_mixin_filters(query, filters, model) - query = self.__add_timestamp_mixin_filters(query, filters, model) + join_field, join_model = join_fieldmap[fieldname] + query = self.__add_join_filter( + query, join_model, join_field, value, operator_) return query @@ -324,74 +328,44 @@ class Connection(api.BaseConnection): def _add_audits_filters(self, query, filters): if filters is None: - filters = [] + filters = {} - if 'uuid' in filters: - query = query.filter_by(uuid=filters['uuid']) - if 'type' in filters: - query = query.filter_by(type=filters['type']) - if 'state' in filters: - query = query.filter_by(state=filters['state']) - if 'audit_template_id' in filters: - query = query.filter_by( - audit_template_id=filters['audit_template_id']) - if 'audit_template_uuid' in filters: - query = query.join( - models.AuditTemplate, - models.Audit.audit_template_id == models.AuditTemplate.id) - query = query.filter( - models.AuditTemplate.uuid == filters['audit_template_uuid']) - if 'audit_template_name' in filters: - query = query.join( - models.AuditTemplate, - models.Audit.audit_template_id == models.AuditTemplate.id) - query = query.filter( - models.AuditTemplate.name == - filters['audit_template_name']) + plain_fields = ['uuid', 'type', 'state', 'audit_template_id'] + join_fieldmap = { + 'audit_template_uuid': ("uuid", models.AuditTemplate), + 'audit_template_name': ("name", models.AuditTemplate), + } - query = self.__add_soft_delete_mixin_filters( - query, filters, models.Audit) - query = self.__add_timestamp_mixin_filters( - query, filters, models.Audit) - - return query + return self._add_filters( + query=query, model=models.Audit, filters=filters, + plain_fields=plain_fields, join_fieldmap=join_fieldmap) def _add_action_plans_filters(self, query, filters): if filters is None: - filters = [] + filters = {} - if 'uuid' in filters: - query = query.filter_by(uuid=filters['uuid']) - if 'state' in filters: - query = query.filter_by(state=filters['state']) - if 'audit_id' in filters: - query = query.filter_by(audit_id=filters['audit_id']) - if 'audit_uuid' in filters: - query = query.join(models.Audit, - models.ActionPlan.audit_id == models.Audit.id) - query = query.filter(models.Audit.uuid == filters['audit_uuid']) + plain_fields = ['uuid', 'state', 'audit_id'] + join_fieldmap = { + 'audit_uuid': ("uuid", models.Audit), + } - query = self.__add_soft_delete_mixin_filters( - query, filters, models.ActionPlan) - query = self.__add_timestamp_mixin_filters( - query, filters, models.ActionPlan) - - return query + return self._add_filters( + query=query, model=models.ActionPlan, filters=filters, + plain_fields=plain_fields, join_fieldmap=join_fieldmap) def _add_actions_filters(self, query, filters): if filters is None: - filters = [] + filters = {} + + plain_fields = ['uuid', 'state', 'action_plan_id'] + join_fieldmap = { + 'action_plan_uuid': ("uuid", models.ActionPlan), + } + + query = self._add_filters( + query=query, model=models.Action, filters=filters, + plain_fields=plain_fields, join_fieldmap=join_fieldmap) - if 'uuid' in filters: - query = query.filter_by(uuid=filters['uuid']) - if 'action_plan_id' in filters: - query = query.filter_by(action_plan_id=filters['action_plan_id']) - if 'action_plan_uuid' in filters: - query = query.join( - models.ActionPlan, - models.Action.action_plan_id == models.ActionPlan.id) - query = query.filter( - models.ActionPlan.uuid == filters['action_plan_uuid']) if 'audit_uuid' in filters: stmt = model_query(models.ActionPlan).join( models.Audit, @@ -399,14 +373,6 @@ class Connection(api.BaseConnection): .filter_by(uuid=filters['audit_uuid']).subquery() query = query.filter_by(action_plan_id=stmt.c.id) - if 'state' in filters: - query = query.filter_by(state=filters['state']) - - query = self.__add_soft_delete_mixin_filters( - query, filters, models.Action) - query = self.__add_timestamp_mixin_filters( - query, filters, models.Action) - return query def _add_efficacy_indicators_filters(self, query, filters): @@ -717,33 +683,16 @@ class Connection(api.BaseConnection): message=_("Cannot overwrite UUID for an existing " "Audit.")) - return self._do_update_audit(audit_id, values) - - def _do_update_audit(self, audit_id, values): - session = get_session() - with session.begin(): - query = model_query(models.Audit, session=session) - query = add_identity_filter(query, audit_id) - try: - ref = query.with_lockmode('update').one() - except exc.NoResultFound: - raise exception.AuditNotFound(audit=audit_id) - - ref.update(values) - return ref + try: + return self._update(models.Audit, audit_id, values) + except exception.ResourceNotFound: + raise exception.AuditNotFound(audit=audit_id) def soft_delete_audit(self, audit_id): - session = get_session() - with session.begin(): - query = model_query(models.Audit, session=session) - query = add_identity_filter(query, audit_id) - - try: - query.one() - except exc.NoResultFound: - raise exception.AuditNotFound(audit=audit_id) - - query.soft_delete() + try: + self._soft_delete(models.Audit, audit_id) + except exception.ResourceNotFound: + raise exception.AuditNotFound(audit=audit_id) # ### ACTIONS ### # @@ -843,7 +792,7 @@ class Connection(api.BaseConnection): # ### ACTION PLANS ### # def get_action_plan_list( - self, context, columns=None, filters=None, limit=None, + self, context, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): query = model_query(models.ActionPlan) query = self._add_action_plans_filters(query, filters) diff --git a/watcher/tests/db/test_audit.py b/watcher/tests/db/test_audit.py index 3866381a7..06138a228 100644 --- a/watcher/tests/db/test_audit.py +++ b/watcher/tests/db/test_audit.py @@ -47,10 +47,12 @@ class TestDbAuditFilters(base.DbTestCase): audit_template_id=self.audit_template.id, id=1, uuid=None) with freezegun.freeze_time(self.FAKE_OLD_DATE): self.audit2 = utils.create_test_audit( - audit_template_id=self.audit_template.id, id=2, uuid=None) + audit_template_id=self.audit_template.id, id=2, uuid=None, + state=audit_objects.State.FAILED) with freezegun.freeze_time(self.FAKE_OLDER_DATE): self.audit3 = utils.create_test_audit( - audit_template_id=self.audit_template.id, id=3, uuid=None) + audit_template_id=self.audit_template.id, id=3, uuid=None, + state=audit_objects.State.CANCELLED) def _soft_delete_audits(self): with freezegun.freeze_time(self.FAKE_TODAY): @@ -225,6 +227,26 @@ class TestDbAuditFilters(base.DbTestCase): [self.audit1['id'], self.audit2['id']], [r.id for r in res]) + def test_get_audit_list_filter_state_in(self): + res = self.dbapi.get_audit_list( + self.context, + filters={'state__in': (audit_objects.State.FAILED, + audit_objects.State.CANCELLED)}) + + self.assertEqual( + [self.audit2['id'], self.audit3['id']], + [r.id for r in res]) + + def test_get_audit_list_filter_state_notin(self): + res = self.dbapi.get_audit_list( + self.context, + filters={'state__notin': (audit_objects.State.FAILED, + audit_objects.State.CANCELLED)}) + + self.assertEqual( + [self.audit1['id']], + [r.id for r in res]) + class DbAuditTestCase(base.DbTestCase):