diff --git a/watcher/objects/audit.py b/watcher/objects/audit.py index c0a3994b5..53d4d6146 100644 --- a/watcher/objects/audit.py +++ b/watcher/objects/audit.py @@ -52,7 +52,8 @@ import enum from watcher.common import exception from watcher.common import utils -from watcher.db import api as dbapi +from watcher.db import api as db_api +from watcher import objects from watcher.objects import base from watcher.objects import fields as wfields @@ -75,10 +76,12 @@ class AuditType(enum.Enum): @base.WatcherObjectRegistry.register class Audit(base.WatcherPersistentObject, base.WatcherObject, base.WatcherObjectDictCompat): - # Version 1.0: Initial version - VERSION = '1.0' - dbapi = dbapi.get_instance() + # Version 1.0: Initial version + # Version 1.1: Added 'goal' and 'strategy' object field + VERSION = '1.1' + + dbapi = db_api.get_instance() fields = { 'id': wfields.IntegerField(), @@ -90,10 +93,18 @@ class Audit(base.WatcherPersistentObject, base.WatcherObject, 'scope': wfields.FlexibleListOfDictField(nullable=True), 'goal_id': wfields.IntegerField(), 'strategy_id': wfields.IntegerField(nullable=True), + + 'goal': wfields.ObjectField('Goal', nullable=True), + 'strategy': wfields.ObjectField('Strategy', nullable=True), + } + + object_fields = { + 'goal': (objects.Goal, 'goal_id'), + 'strategy': (objects.Strategy, 'strategy_id'), } @base.remotable_classmethod - def get(cls, context, audit_id): + def get(cls, context, audit_id, eager=False): """Find a audit based on its id or uuid and return a Audit object. :param context: Security context. NOTE: This should only @@ -103,17 +114,18 @@ class Audit(base.WatcherPersistentObject, base.WatcherObject, A context should be set when instantiating the object, e.g.: Audit(context) :param audit_id: the id *or* uuid of a audit. + :param eager: Load object fields if True (Default: False) :returns: a :class:`Audit` object. """ if utils.is_int_like(audit_id): - return cls.get_by_id(context, audit_id) + return cls.get_by_id(context, audit_id, eager=eager) elif utils.is_uuid_like(audit_id): - return cls.get_by_uuid(context, audit_id) + return cls.get_by_uuid(context, audit_id, eager=eager) else: raise exception.InvalidIdentity(identity=audit_id) @base.remotable_classmethod - def get_by_id(cls, context, audit_id): + def get_by_id(cls, context, audit_id, eager=False): """Find a audit based on its integer id and return a Audit object. :param context: Security context. NOTE: This should only @@ -123,14 +135,15 @@ class Audit(base.WatcherPersistentObject, base.WatcherObject, A context should be set when instantiating the object, e.g.: Audit(context) :param audit_id: the id of a audit. + :param eager: Load object fields if True (Default: False) :returns: a :class:`Audit` object. """ - db_audit = cls.dbapi.get_audit_by_id(context, audit_id) - audit = Audit._from_db_object(cls(context), db_audit) + db_audit = cls.dbapi.get_audit_by_id(context, audit_id, eager=eager) + audit = cls._from_db_object(cls(context), db_audit, eager=eager) return audit @base.remotable_classmethod - def get_by_uuid(cls, context, uuid): + def get_by_uuid(cls, context, uuid, eager=False): """Find a audit based on uuid and return a :class:`Audit` object. :param context: Security context. NOTE: This should only @@ -140,16 +153,17 @@ class Audit(base.WatcherPersistentObject, base.WatcherObject, A context should be set when instantiating the object, e.g.: Audit(context) :param uuid: the uuid of a audit. + :param eager: Load object fields if True (Default: False) :returns: a :class:`Audit` object. """ - db_audit = cls.dbapi.get_audit_by_uuid(context, uuid) - audit = Audit._from_db_object(cls(context), db_audit) + db_audit = cls.dbapi.get_audit_by_uuid(context, uuid, eager=eager) + audit = cls._from_db_object(cls(context), db_audit, eager=eager) return audit @base.remotable_classmethod def list(cls, context, limit=None, marker=None, filters=None, - sort_key=None, sort_dir=None): + sort_key=None, sort_dir=None, eager=False): """Return a list of Audit objects. :param context: Security context. NOTE: This should only @@ -163,6 +177,7 @@ class Audit(base.WatcherPersistentObject, base.WatcherObject, :param filters: Filters to apply. Defaults to None. :param sort_key: column to sort results by. :param sort_dir: direction to sort. "asc" or "desc". + :param eager: Load object fields if True (Default: False) :returns: a list of :class:`Audit` object. """ @@ -171,16 +186,24 @@ class Audit(base.WatcherPersistentObject, base.WatcherObject, marker=marker, filters=filters, sort_key=sort_key, - sort_dir=sort_dir) - return [cls._from_db_object(cls(context), obj) for obj in db_audits] + sort_dir=sort_dir, + eager=eager) + return [cls._from_db_object(cls(context), obj, eager=eager) + for obj in db_audits] @base.remotable def create(self): - """Create a Audit record in the DB.""" + """Create an :class:`Audit` record in the DB. + + :returns: An :class:`Audit` object. + """ values = self.obj_get_changes() db_audit = self.dbapi.create_audit(values) - self._from_db_object(self, db_audit) + # Note(v-francoise): Always load eagerly upon creation so we can send + # notifications containing information about the related relationships + self._from_db_object(self, db_audit, eager=True) + @base.remotable def destroy(self): """Delete the Audit from the DB.""" self.dbapi.destroy_audit(self.uuid) @@ -199,14 +222,15 @@ class Audit(base.WatcherPersistentObject, base.WatcherObject, self.obj_reset_changes() @base.remotable - def refresh(self): + def refresh(self, eager=False): """Loads updates for this Audit. Loads a audit with the same uuid from the database and checks for updated attributes. Updates are applied from the loaded audit column by column, if there are any updates. + :param eager: Load object fields if True (Default: False) """ - current = self.__class__.get_by_uuid(self._context, uuid=self.uuid) + current = self.get_by_uuid(self._context, uuid=self.uuid, eager=eager) self.obj_refresh(current) @base.remotable diff --git a/watcher/tests/api/v1/test_actions.py b/watcher/tests/api/v1/test_actions.py index 3fb6751b3..909367f16 100644 --- a/watcher/tests/api/v1/test_actions.py +++ b/watcher/tests/api/v1/test_actions.py @@ -52,6 +52,8 @@ class TestListAction(api_base.FunctionalTest): def setUp(self): super(TestListAction, self).setUp() + obj_utils.create_test_goal(self.context) + obj_utils.create_test_strategy(self.context) obj_utils.create_test_action_plan(self.context) def test_empty(self): diff --git a/watcher/tests/api/v1/test_actions_plans.py b/watcher/tests/api/v1/test_actions_plans.py index 404a29d5e..6410fb029 100644 --- a/watcher/tests/api/v1/test_actions_plans.py +++ b/watcher/tests/api/v1/test_actions_plans.py @@ -27,6 +27,11 @@ from watcher.tests.objects import utils as obj_utils class TestListActionPlan(api_base.FunctionalTest): + def setUp(self): + super(TestListActionPlan, self).setUp() + obj_utils.create_test_goal(self.context) + obj_utils.create_test_audit(self.context) + def test_empty(self): response = self.get_json('/action_plans') self.assertEqual([], response['action_plans']) @@ -131,10 +136,10 @@ class TestListActionPlan(api_base.FunctionalTest): def test_many_with_soft_deleted_audit_uuid(self): action_plan_list = [] audit1 = obj_utils.create_test_audit(self.context, - id=1, + id=2, uuid=utils.generate_uuid()) audit2 = obj_utils.create_test_audit(self.context, - id=2, + id=3, uuid=utils.generate_uuid()) for id_ in range(0, 2): diff --git a/watcher/tests/api/v1/test_audits.py b/watcher/tests/api/v1/test_audits.py index 70c7b94f5..e523cfddc 100644 --- a/watcher/tests/api/v1/test_audits.py +++ b/watcher/tests/api/v1/test_audits.py @@ -683,6 +683,10 @@ class TestDelete(api_base.FunctionalTest): class TestAuditPolicyEnforcement(api_base.FunctionalTest): + def setUp(self): + super(TestAuditPolicyEnforcement, self).setUp() + obj_utils.create_test_goal(self.context) + def _common_policy_check(self, rule, func, *arg, **kwarg): self.policy.set_rules({ "admin_api": "(role:admin or role:administrator)", diff --git a/watcher/tests/db/utils.py b/watcher/tests/db/utils.py index c540f4a4b..5d99a43ba 100644 --- a/watcher/tests/db/utils.py +++ b/watcher/tests/db/utils.py @@ -61,7 +61,7 @@ def create_test_audit_template(**kwargs): def get_test_audit(**kwargs): - return { + audit_data = { 'id': kwargs.get('id', 1), 'uuid': kwargs.get('uuid', '10a47dd1-4874-4298-91cf-eff046dbdb8d'), 'audit_type': kwargs.get('audit_type', 'ONESHOT'), @@ -76,6 +76,15 @@ def get_test_audit(**kwargs): 'scope': kwargs.get('scope', []), } + # ObjectField doesn't allow None nor dict, so if we want to simulate a + # non-eager object loading, the field should not be referenced at all. + if kwargs.get('goal'): + audit_data['goal'] = kwargs.get('goal') + if kwargs.get('strategy'): + audit_data['strategy'] = kwargs.get('strategy') + + return audit_data + def create_test_audit(**kwargs): """Create test audit entry in DB and return Audit DB object. diff --git a/watcher/tests/objects/test_audit.py b/watcher/tests/objects/test_audit.py index eb935570e..144646422 100644 --- a/watcher/tests/objects/test_audit.py +++ b/watcher/tests/objects/test_audit.py @@ -14,8 +14,10 @@ # under the License. import mock + from watcher.common import exception -# from watcher.common import utils as w_utils +from watcher.common import utils as w_utils +from watcher.db.sqlalchemy import api as db_api from watcher import objects from watcher.tests.db import base from watcher.tests.db import utils @@ -23,94 +25,154 @@ from watcher.tests.db import utils class TestAuditObject(base.DbTestCase): + goal_id = 2 + + goal_data = utils.get_test_goal( + id=goal_id, uuid=w_utils.generate_uuid(), name="DUMMY") + + scenarios = [ + ('non_eager', dict( + eager=False, + fake_audit=utils.get_test_audit( + goal_id=goal_id))), + ('eager_with_non_eager_load', dict( + eager=True, + fake_audit=utils.get_test_audit( + goal_id=goal_id))), + ('eager_with_eager_load', dict( + eager=True, + fake_audit=utils.get_test_audit(goal_id=goal_id, goal=goal_data))), + ] + def setUp(self): super(TestAuditObject, self).setUp() - self.fake_audit = utils.get_test_audit() + self.fake_goal = utils.create_test_goal(**self.goal_data) - def test_get_by_id(self): + def eager_load_audit_assert(self, audit, goal): + if self.eager: + self.assertIsNotNone(audit.goal) + fields_to_check = set( + super(objects.Goal, objects.Goal).fields + ).symmetric_difference(objects.Goal.fields) + db_data = { + k: v for k, v in goal.as_dict().items() + if k in fields_to_check} + object_data = { + k: v for k, v in audit.goal.as_dict().items() + if k in fields_to_check} + self.assertEqual(db_data, object_data) + + @mock.patch.object(db_api.Connection, 'get_audit_by_id') + def test_get_by_id(self, mock_get_audit): + mock_get_audit.return_value = self.fake_audit audit_id = self.fake_audit['id'] - with mock.patch.object(self.dbapi, 'get_audit_by_id', - autospec=True) as mock_get_audit: - mock_get_audit.return_value = self.fake_audit - audit = objects.Audit.get(self.context, audit_id) - mock_get_audit.assert_called_once_with(self.context, - audit_id) - self.assertEqual(self.context, audit._context) + audit = objects.Audit.get(self.context, audit_id, eager=self.eager) + mock_get_audit.assert_called_once_with( + self.context, audit_id, eager=self.eager) + self.assertEqual(self.context, audit._context) + self.eager_load_audit_assert(audit, self.fake_goal) - def test_get_by_uuid(self): + @mock.patch.object(db_api.Connection, 'get_audit_by_uuid') + def test_get_by_uuid(self, mock_get_audit): + mock_get_audit.return_value = self.fake_audit uuid = self.fake_audit['uuid'] - with mock.patch.object(self.dbapi, 'get_audit_by_uuid', - autospec=True) as mock_get_audit: - mock_get_audit.return_value = self.fake_audit - audit = objects.Audit.get(self.context, uuid) - mock_get_audit.assert_called_once_with(self.context, uuid) - self.assertEqual(self.context, audit._context) + audit = objects.Audit.get(self.context, uuid, eager=self.eager) + mock_get_audit.assert_called_once_with( + self.context, uuid, eager=self.eager) + self.assertEqual(self.context, audit._context) + self.eager_load_audit_assert(audit, self.fake_goal) def test_get_bad_id_and_uuid(self): self.assertRaises(exception.InvalidIdentity, - objects.Audit.get, self.context, 'not-a-uuid') + objects.Audit.get, self.context, + 'not-a-uuid', eager=self.eager) - def test_list(self): - with mock.patch.object(self.dbapi, 'get_audit_list', - autospec=True) as mock_get_list: - mock_get_list.return_value = [self.fake_audit] - audits = objects.Audit.list(self.context) - self.assertEqual(1, mock_get_list.call_count, 1) - self.assertEqual(1, len(audits)) - self.assertIsInstance(audits[0], objects.Audit) - self.assertEqual(self.context, audits[0]._context) + @mock.patch.object(db_api.Connection, 'get_audit_list') + def test_list(self, mock_get_list): + mock_get_list.return_value = [self.fake_audit] + audits = objects.Audit.list(self.context, eager=self.eager) + mock_get_list.assert_called_once_with( + self.context, eager=self.eager, filters=None, limit=None, + marker=None, sort_dir=None, sort_key=None) + self.assertEqual(1, len(audits)) + self.assertIsInstance(audits[0], objects.Audit) + self.assertEqual(self.context, audits[0]._context) + for audit in audits: + self.eager_load_audit_assert(audit, self.fake_goal) - def test_create(self): - with mock.patch.object(self.dbapi, 'create_audit', - autospec=True) as mock_create_audit: - mock_create_audit.return_value = self.fake_audit - audit = objects.Audit(self.context, **self.fake_audit) - - audit.create() - mock_create_audit.assert_called_once_with(self.fake_audit) - self.assertEqual(self.context, audit._context) - - def test_destroy(self): + @mock.patch.object(db_api.Connection, 'update_audit') + @mock.patch.object(db_api.Connection, 'get_audit_by_uuid') + def test_save(self, mock_get_audit, mock_update_audit): + mock_get_audit.return_value = self.fake_audit uuid = self.fake_audit['uuid'] - with mock.patch.object(self.dbapi, 'get_audit_by_uuid', - autospec=True) as mock_get_audit: - mock_get_audit.return_value = self.fake_audit - with mock.patch.object(self.dbapi, 'destroy_audit', - autospec=True) as mock_destroy_audit: - audit = objects.Audit.get_by_uuid(self.context, uuid) - audit.destroy() - mock_get_audit.assert_called_once_with(self.context, uuid) - mock_destroy_audit.assert_called_once_with(uuid) - self.assertEqual(self.context, audit._context) + audit = objects.Audit.get_by_uuid(self.context, uuid, eager=self.eager) + audit.state = 'SUCCEEDED' + audit.save() - def test_save(self): - uuid = self.fake_audit['uuid'] - with mock.patch.object(self.dbapi, 'get_audit_by_uuid', - autospec=True) as mock_get_audit: - mock_get_audit.return_value = self.fake_audit - with mock.patch.object(self.dbapi, 'update_audit', - autospec=True) as mock_update_audit: - audit = objects.Audit.get_by_uuid(self.context, uuid) - audit.state = 'SUCCEEDED' - audit.save() + mock_get_audit.assert_called_once_with( + self.context, uuid, eager=self.eager) + mock_update_audit.assert_called_once_with( + uuid, {'state': 'SUCCEEDED'}) + self.assertEqual(self.context, audit._context) + self.eager_load_audit_assert(audit, self.fake_goal) - mock_get_audit.assert_called_once_with(self.context, uuid) - mock_update_audit.assert_called_once_with( - uuid, {'state': 'SUCCEEDED'}) - self.assertEqual(self.context, audit._context) - - def test_refresh(self): - uuid = self.fake_audit['uuid'] + @mock.patch.object(db_api.Connection, 'get_audit_by_uuid') + def test_refresh(self, mock_get_audit): returns = [dict(self.fake_audit, state="first state"), dict(self.fake_audit, state="second state")] - expected = [mock.call(self.context, uuid), - mock.call(self.context, uuid)] - with mock.patch.object(self.dbapi, 'get_audit_by_uuid', - side_effect=returns, - autospec=True) as mock_get_audit: - audit = objects.Audit.get(self.context, uuid) - self.assertEqual("first state", audit.state) - audit.refresh() - self.assertEqual("second state", audit.state) - self.assertEqual(expected, mock_get_audit.call_args_list) - self.assertEqual(self.context, audit._context) + mock_get_audit.side_effect = returns + uuid = self.fake_audit['uuid'] + expected = [ + mock.call(self.context, uuid, eager=self.eager), + mock.call(self.context, uuid, eager=self.eager)] + audit = objects.Audit.get(self.context, uuid, eager=self.eager) + self.assertEqual("first state", audit.state) + audit.refresh(eager=self.eager) + self.assertEqual("second state", audit.state) + self.assertEqual(expected, mock_get_audit.call_args_list) + self.assertEqual(self.context, audit._context) + self.eager_load_audit_assert(audit, self.fake_goal) + + +class TestCreateDeleteAuditObject(base.DbTestCase): + + def setUp(self): + super(TestCreateDeleteAuditObject, self).setUp() + self.goal_id = 1 + self.fake_audit = utils.get_test_audit(goal_id=self.goal_id) + + @mock.patch.object(db_api.Connection, 'create_audit') + def test_create(self, mock_create_audit): + utils.create_test_goal(id=self.goal_id) + mock_create_audit.return_value = self.fake_audit + audit = objects.Audit(self.context, **self.fake_audit) + audit.create() + mock_create_audit.assert_called_once_with(self.fake_audit) + self.assertEqual(self.context, audit._context) + + @mock.patch.object(db_api.Connection, 'update_audit') + @mock.patch.object(db_api.Connection, 'soft_delete_audit') + @mock.patch.object(db_api.Connection, 'get_audit_by_uuid') + def test_soft_delete(self, mock_get_audit, + mock_soft_delete_audit, mock_update_audit): + mock_get_audit.return_value = self.fake_audit + uuid = self.fake_audit['uuid'] + audit = objects.Audit.get_by_uuid(self.context, uuid) + audit.soft_delete() + mock_get_audit.assert_called_once_with(self.context, uuid, eager=False) + mock_soft_delete_audit.assert_called_once_with(uuid) + mock_update_audit.assert_called_once_with(uuid, {'state': 'DELETED'}) + self.assertEqual(self.context, audit._context) + + @mock.patch.object(db_api.Connection, 'destroy_audit') + @mock.patch.object(db_api.Connection, 'get_audit_by_uuid') + def test_destroy(self, mock_get_audit, + mock_destroy_audit): + mock_get_audit.return_value = self.fake_audit + uuid = self.fake_audit['uuid'] + audit = objects.Audit.get_by_uuid(self.context, uuid) + audit.destroy() + mock_get_audit.assert_called_once_with( + self.context, uuid, eager=False) + mock_destroy_audit.assert_called_once_with(uuid) + self.assertEqual(self.context, audit._context) diff --git a/watcher/tests/objects/test_objects.py b/watcher/tests/objects/test_objects.py index 8857c6d3d..92b237a8f 100644 --- a/watcher/tests/objects/test_objects.py +++ b/watcher/tests/objects/test_objects.py @@ -412,7 +412,7 @@ expected_object_fingerprints = { 'Goal': '1.0-93881622db05e7b67a65ca885b4a022e', 'Strategy': '1.1-73f164491bdd4c034f48083a51bdeb7b', 'AuditTemplate': '1.1-b291973ffc5efa2c61b24fe34fdccc0b', - 'Audit': '1.0-ebfc5360d019baf583a10a8a27071c97', + 'Audit': '1.1-dc246337c8d511646cb537144fcb0f3a', 'ActionPlan': '1.0-cc76fd7f0e8479aeff817dd266341de4', 'Action': '1.0-a78f69c0da98e13e601f9646f6b2f883', 'EfficacyIndicator': '1.0-655b71234a82bc7478aff964639c4bb0',