Merge "Added filter operators"

This commit is contained in:
Jenkins
2016-07-04 07:10:50 +00:00
committed by Gerrit Code Review
4 changed files with 162 additions and 196 deletions

View File

@@ -150,6 +150,11 @@ class InvalidIdentity(Invalid):
msg_fmt = _("Expected a uuid or int but received %(identity)s")
class InvalidOperator(Invalid):
msg_fmt = _("Filter operator is not valid: %(operator)s not "
"in %(valid_operators)s")
class InvalidGoal(Invalid):
msg_fmt = _("Goal %(goal)s is invalid")

View File

@@ -139,8 +139,6 @@ class BaseConnection(object):
match the specified filters.
:param context: The security context
:param columns: List of column names to return.
Defaults to 'id' column when columns == None.
:param filters: Filters to apply. Defaults to None.
:param limit: Maximum number of strategies to return.
@@ -221,7 +219,7 @@ class BaseConnection(object):
"""
@abc.abstractmethod
def get_audit_template_list(self, context, columns=None, filters=None,
def get_audit_template_list(self, context, filters=None,
limit=None, marker=None, sort_key=None,
sort_dir=None):
"""Get specific columns for matching audit templates.
@@ -230,8 +228,6 @@ class BaseConnection(object):
match the specified filters.
:param context: The security context
:param columns: List of column names to return.
Defaults to 'id' column when columns == None.
:param filters: Filters to apply. Defaults to None.
:param limit: Maximum number of audit templates to return.
@@ -320,7 +316,7 @@ class BaseConnection(object):
"""
@abc.abstractmethod
def get_audit_list(self, context, columns=None, filters=None, limit=None,
def get_audit_list(self, context, filters=None, limit=None,
marker=None, sort_key=None, sort_dir=None):
"""Get specific columns for matching audits.
@@ -328,8 +324,6 @@ class BaseConnection(object):
specified filters.
:param context: The security context
:param columns: List of column names to return.
Defaults to 'id' column when columns == None.
:param filters: Filters to apply. Defaults to None.
:param limit: Maximum number of audits to return.
@@ -407,7 +401,7 @@ class BaseConnection(object):
"""
@abc.abstractmethod
def get_action_list(self, context, columns=None, filters=None, limit=None,
def get_action_list(self, context, filters=None, limit=None,
marker=None, sort_key=None, sort_dir=None):
"""Get specific columns for matching actions.
@@ -415,8 +409,6 @@ class BaseConnection(object):
specified filters.
:param context: The security context
:param columns: List of column names to return.
Defaults to 'id' column when columns == None.
:param filters: Filters to apply. Defaults to None.
:param limit: Maximum number of actions to return.
@@ -490,7 +482,7 @@ class BaseConnection(object):
@abc.abstractmethod
def get_action_plan_list(
self, context, columns=None, filters=None, limit=None,
self, context, filters=None, limit=None,
marker=None, sort_key=None, sort_dir=None):
"""Get specific columns for matching action plans.
@@ -498,8 +490,6 @@ class BaseConnection(object):
match the specified filters.
:param context: The security context
:param columns: List of column names to return.
Defaults to 'id' column when columns == None.
:param filters: Filters to apply. Defaults to None.
:param limit: Maximum number of audits to return.

View File

@@ -14,10 +14,11 @@
# License for the specific language governing permissions and limitations
# under the License.
"""SQLAlchemy storage backend."""
import collections
import datetime
import operator
from oslo_config import cfg
from oslo_db import exception as db_exc
@@ -25,7 +26,7 @@ from oslo_db.sqlalchemy import session as db_session
from oslo_db.sqlalchemy import utils as db_utils
from sqlalchemy.orm import exc
from watcher import _i18n
from watcher._i18n import _
from watcher.common import exception
from watcher.common import utils
from watcher.db import api
@@ -36,7 +37,6 @@ from watcher.objects import audit as audit_objects
from watcher.objects import utils as objutils
CONF = cfg.CONF
_ = _i18n._
_FACADE = None
@@ -115,117 +115,121 @@ NaturalJoinFilter = collections.namedtuple(
class Connection(api.BaseConnection):
"""SqlAlchemy connection."""
valid_operators = {
"": operator.eq,
"eq": operator.eq,
"neq": operator.ne,
"gt": operator.gt,
"gte": operator.ge,
"lt": operator.lt,
"lte": operator.le,
"in": lambda field, choices: field.in_(choices),
"notin": lambda field, choices: field.notin_(choices),
}
def __init__(self):
super(Connection, self).__init__()
def __add_soft_delete_mixin_filters(self, query, filters, model):
if 'deleted' in filters:
if bool(filters['deleted']):
query = query.filter(model.deleted != 0)
else:
query = query.filter(model.deleted == 0)
if 'deleted_at__eq' in filters:
query = query.filter(
model.deleted_at == objutils.datetime_or_str_or_none(
filters['deleted_at__eq']))
if 'deleted_at__gt' in filters:
query = query.filter(
model.deleted_at > objutils.datetime_or_str_or_none(
filters['deleted_at__gt']))
if 'deleted_at__gte' in filters:
query = query.filter(
model.deleted_at >= objutils.datetime_or_str_or_none(
filters['deleted_at__gte']))
if 'deleted_at__lt' in filters:
query = query.filter(
model.deleted_at < objutils.datetime_or_str_or_none(
filters['deleted_at__lt']))
if 'deleted_at__lte' in filters:
query = query.filter(
model.deleted_at <= objutils.datetime_or_str_or_none(
filters['deleted_at__lte']))
def __add_simple_filter(self, query, model, fieldname, value, operator_):
field = getattr(model, fieldname)
return query
if field.type.python_type is datetime.datetime:
value = objutils.datetime_or_str_or_none(value)
def __add_timestamp_mixin_filters(self, query, filters, model):
if 'created_at__eq' in filters:
query = query.filter(
model.created_at == objutils.datetime_or_str_or_none(
filters['created_at__eq']))
if 'created_at__gt' in filters:
query = query.filter(
model.created_at > objutils.datetime_or_str_or_none(
filters['created_at__gt']))
if 'created_at__gte' in filters:
query = query.filter(
model.created_at >= objutils.datetime_or_str_or_none(
filters['created_at__gte']))
if 'created_at__lt' in filters:
query = query.filter(
model.created_at < objutils.datetime_or_str_or_none(
filters['created_at__lt']))
if 'created_at__lte' in filters:
query = query.filter(
model.created_at <= objutils.datetime_or_str_or_none(
filters['created_at__lte']))
return query.filter(self.valid_operators[operator_](field, value))
if 'updated_at__eq' in filters:
query = query.filter(
model.updated_at == objutils.datetime_or_str_or_none(
filters['updated_at__eq']))
if 'updated_at__gt' in filters:
query = query.filter(
model.updated_at > objutils.datetime_or_str_or_none(
filters['updated_at__gt']))
if 'updated_at__gte' in filters:
query = query.filter(
model.updated_at >= objutils.datetime_or_str_or_none(
filters['updated_at__gte']))
if 'updated_at__lt' in filters:
query = query.filter(
model.updated_at < objutils.datetime_or_str_or_none(
filters['updated_at__lt']))
if 'updated_at__lte' in filters:
query = query.filter(
model.updated_at <= objutils.datetime_or_str_or_none(
filters['updated_at__lte']))
def __add_join_filter(self, query, model, fieldname, value, operator_):
query = query.join(model)
return self.__add_simple_filter(query, model, fieldname,
value, operator_)
return query
def __decompose_filter(self, raw_fieldname):
"""Decompose a filter name into its 2 subparts
def __add_simple_filter(self, query, model, fieldname, value):
return query.filter(getattr(model, fieldname) == value)
A filter can take 2 forms:
def __add_natural_join_filter(self, query, join_model,
join_fieldname, value):
query = query.join(join_model)
return self.__add_simple_filter(
query, join_model, join_fieldname, value)
- "<FIELDNAME>" which is a syntactic sugar for "<FIELDNAME>__eq"
- "<FIELDNAME>__<OPERATOR>" where <OPERATOR> is the comparison operator
to be used.
Available operators are:
- eq
- neq
- gt
- gte
- lt
- lte
- in
- notin
"""
separator = '__'
fieldname, separator, operator_ = raw_fieldname.partition(separator)
if operator_ and operator_ not in self.valid_operators:
raise exception.InvalidOperator(
operator=operator_, valid_operators=self.valid_operators)
return fieldname, operator_
def _add_filters(self, query, model, filters=None,
plain_fields=None, join_fieldmap=None):
"""Generic way to add filters to a Watcher model
Each filter key provided by the `filters` parameter will be decomposed
into 2 pieces: the field name and the comparison operator
- "": By default, the "eq" is applied if no operator is provided
- "eq", which stands for "equal" : e.g. {"state__eq": "PENDING"}
will result in the "WHERE state = 'PENDING'" clause.
- "neq", which stands for "not equal" : e.g. {"state__neq": "PENDING"}
will result in the "WHERE state != 'PENDING'" clause.
- "gt", which stands for "greater than" : e.g.
{"created_at__gt": "2016-06-06T10:33:22.063176"} will result in the
"WHERE created_at > '2016-06-06T10:33:22.063176'" clause.
- "gte", which stands for "greater than or equal to" : e.g.
{"created_at__gte": "2016-06-06T10:33:22.063176"} will result in the
"WHERE created_at >= '2016-06-06T10:33:22.063176'" clause.
- "lt", which stands for "less than" : e.g.
{"created_at__lt": "2016-06-06T10:33:22.063176"} will result in the
"WHERE created_at < '2016-06-06T10:33:22.063176'" clause.
- "lte", which stands for "less than or equal to" : e.g.
{"created_at__lte": "2016-06-06T10:33:22.063176"} will result in the
"WHERE created_at <= '2016-06-06T10:33:22.063176'" clause.
- "in": e.g. {"state__in": ('SUCCEEDED', 'FAILED')} will result in the
"WHERE state IN ('SUCCEEDED', 'FAILED')" clause.
:param query: a :py:class:`sqlalchemy.orm.query.Query` instance
:param model: the model class the filters should relate to
:param filters: dict with the following structure {"fieldname": value}
:param plain_fields: a :py:class:`sqlalchemy.orm.query.Query` instance
:param join_fieldmap: a :py:class:`sqlalchemy.orm.query.Query` instance
"""
soft_delete_mixin_fields = ['deleted', 'deleted_at']
timestamp_mixin_fields = ['created_at', 'updated_at']
filters = filters or {}
plain_fields = plain_fields or ()
join_fieldmap = join_fieldmap or JoinMap()
for fieldname, value in filters.items():
# Special case for 'deleted' because it is a non-boolean flag
if 'deleted' in filters:
deleted_filter = filters.pop('deleted')
op = 'eq' if not bool(deleted_filter) else 'neq'
filters['deleted__%s' % op] = 0
plain_fields = tuple(
(list(plain_fields) or []) +
soft_delete_mixin_fields +
timestamp_mixin_fields)
join_fieldmap = join_fieldmap or {}
for raw_fieldname, value in filters.items():
fieldname, operator_ = self.__decompose_filter(raw_fieldname)
if fieldname in plain_fields:
query = self.__add_simple_filter(
query, model, fieldname, value)
query, model, fieldname, value, operator_)
elif fieldname in join_fieldmap:
join_fieldname, join_model = join_fieldmap[fieldname]
query = self.__add_natural_join_filter(
query, join_model, join_fieldname, value)
query = self.__add_soft_delete_mixin_filters(query, filters, model)
query = self.__add_timestamp_mixin_filters(query, filters, model)
join_field, join_model = join_fieldmap[fieldname]
query = self.__add_join_filter(
query, join_model, join_field, value, operator_)
return query
@@ -324,74 +328,44 @@ class Connection(api.BaseConnection):
def _add_audits_filters(self, query, filters):
if filters is None:
filters = []
filters = {}
if 'uuid' in filters:
query = query.filter_by(uuid=filters['uuid'])
if 'type' in filters:
query = query.filter_by(type=filters['type'])
if 'state' in filters:
query = query.filter_by(state=filters['state'])
if 'audit_template_id' in filters:
query = query.filter_by(
audit_template_id=filters['audit_template_id'])
if 'audit_template_uuid' in filters:
query = query.join(
models.AuditTemplate,
models.Audit.audit_template_id == models.AuditTemplate.id)
query = query.filter(
models.AuditTemplate.uuid == filters['audit_template_uuid'])
if 'audit_template_name' in filters:
query = query.join(
models.AuditTemplate,
models.Audit.audit_template_id == models.AuditTemplate.id)
query = query.filter(
models.AuditTemplate.name ==
filters['audit_template_name'])
plain_fields = ['uuid', 'type', 'state', 'audit_template_id']
join_fieldmap = {
'audit_template_uuid': ("uuid", models.AuditTemplate),
'audit_template_name': ("name", models.AuditTemplate),
}
query = self.__add_soft_delete_mixin_filters(
query, filters, models.Audit)
query = self.__add_timestamp_mixin_filters(
query, filters, models.Audit)
return query
return self._add_filters(
query=query, model=models.Audit, filters=filters,
plain_fields=plain_fields, join_fieldmap=join_fieldmap)
def _add_action_plans_filters(self, query, filters):
if filters is None:
filters = []
filters = {}
if 'uuid' in filters:
query = query.filter_by(uuid=filters['uuid'])
if 'state' in filters:
query = query.filter_by(state=filters['state'])
if 'audit_id' in filters:
query = query.filter_by(audit_id=filters['audit_id'])
if 'audit_uuid' in filters:
query = query.join(models.Audit,
models.ActionPlan.audit_id == models.Audit.id)
query = query.filter(models.Audit.uuid == filters['audit_uuid'])
plain_fields = ['uuid', 'state', 'audit_id']
join_fieldmap = {
'audit_uuid': ("uuid", models.Audit),
}
query = self.__add_soft_delete_mixin_filters(
query, filters, models.ActionPlan)
query = self.__add_timestamp_mixin_filters(
query, filters, models.ActionPlan)
return query
return self._add_filters(
query=query, model=models.ActionPlan, filters=filters,
plain_fields=plain_fields, join_fieldmap=join_fieldmap)
def _add_actions_filters(self, query, filters):
if filters is None:
filters = []
filters = {}
plain_fields = ['uuid', 'state', 'action_plan_id']
join_fieldmap = {
'action_plan_uuid': ("uuid", models.ActionPlan),
}
query = self._add_filters(
query=query, model=models.Action, filters=filters,
plain_fields=plain_fields, join_fieldmap=join_fieldmap)
if 'uuid' in filters:
query = query.filter_by(uuid=filters['uuid'])
if 'action_plan_id' in filters:
query = query.filter_by(action_plan_id=filters['action_plan_id'])
if 'action_plan_uuid' in filters:
query = query.join(
models.ActionPlan,
models.Action.action_plan_id == models.ActionPlan.id)
query = query.filter(
models.ActionPlan.uuid == filters['action_plan_uuid'])
if 'audit_uuid' in filters:
stmt = model_query(models.ActionPlan).join(
models.Audit,
@@ -399,14 +373,6 @@ class Connection(api.BaseConnection):
.filter_by(uuid=filters['audit_uuid']).subquery()
query = query.filter_by(action_plan_id=stmt.c.id)
if 'state' in filters:
query = query.filter_by(state=filters['state'])
query = self.__add_soft_delete_mixin_filters(
query, filters, models.Action)
query = self.__add_timestamp_mixin_filters(
query, filters, models.Action)
return query
def _add_efficacy_indicators_filters(self, query, filters):
@@ -717,33 +683,16 @@ class Connection(api.BaseConnection):
message=_("Cannot overwrite UUID for an existing "
"Audit."))
return self._do_update_audit(audit_id, values)
def _do_update_audit(self, audit_id, values):
session = get_session()
with session.begin():
query = model_query(models.Audit, session=session)
query = add_identity_filter(query, audit_id)
try:
ref = query.with_lockmode('update').one()
except exc.NoResultFound:
raise exception.AuditNotFound(audit=audit_id)
ref.update(values)
return ref
try:
return self._update(models.Audit, audit_id, values)
except exception.ResourceNotFound:
raise exception.AuditNotFound(audit=audit_id)
def soft_delete_audit(self, audit_id):
session = get_session()
with session.begin():
query = model_query(models.Audit, session=session)
query = add_identity_filter(query, audit_id)
try:
query.one()
except exc.NoResultFound:
raise exception.AuditNotFound(audit=audit_id)
query.soft_delete()
try:
self._soft_delete(models.Audit, audit_id)
except exception.ResourceNotFound:
raise exception.AuditNotFound(audit=audit_id)
# ### ACTIONS ### #
@@ -843,7 +792,7 @@ class Connection(api.BaseConnection):
# ### ACTION PLANS ### #
def get_action_plan_list(
self, context, columns=None, filters=None, limit=None,
self, context, filters=None, limit=None,
marker=None, sort_key=None, sort_dir=None):
query = model_query(models.ActionPlan)
query = self._add_action_plans_filters(query, filters)

View File

@@ -47,10 +47,12 @@ class TestDbAuditFilters(base.DbTestCase):
audit_template_id=self.audit_template.id, id=1, uuid=None)
with freezegun.freeze_time(self.FAKE_OLD_DATE):
self.audit2 = utils.create_test_audit(
audit_template_id=self.audit_template.id, id=2, uuid=None)
audit_template_id=self.audit_template.id, id=2, uuid=None,
state=audit_objects.State.FAILED)
with freezegun.freeze_time(self.FAKE_OLDER_DATE):
self.audit3 = utils.create_test_audit(
audit_template_id=self.audit_template.id, id=3, uuid=None)
audit_template_id=self.audit_template.id, id=3, uuid=None,
state=audit_objects.State.CANCELLED)
def _soft_delete_audits(self):
with freezegun.freeze_time(self.FAKE_TODAY):
@@ -225,6 +227,26 @@ class TestDbAuditFilters(base.DbTestCase):
[self.audit1['id'], self.audit2['id']],
[r.id for r in res])
def test_get_audit_list_filter_state_in(self):
res = self.dbapi.get_audit_list(
self.context,
filters={'state__in': (audit_objects.State.FAILED,
audit_objects.State.CANCELLED)})
self.assertEqual(
[self.audit2['id'], self.audit3['id']],
[r.id for r in res])
def test_get_audit_list_filter_state_notin(self):
res = self.dbapi.get_audit_list(
self.context,
filters={'state__notin': (audit_objects.State.FAILED,
audit_objects.State.CANCELLED)})
self.assertEqual(
[self.audit1['id']],
[r.id for r in res])
class DbAuditTestCase(base.DbTestCase):