diff --git a/watcher/objects/action.py b/watcher/objects/action.py index 42c41fe6e..d922d56d1 100644 --- a/watcher/objects/action.py +++ b/watcher/objects/action.py @@ -16,7 +16,8 @@ from watcher.common import exception from watcher.common import utils -from watcher.db import api as dbapi +from watcher.db import api as db_api +from watcher import objects from watcher.objects import base from watcher.objects import fields as wfields @@ -35,60 +36,69 @@ class Action(base.WatcherPersistentObject, base.WatcherObject, base.WatcherObjectDictCompat): # Version 1.0: Initial version - VERSION = '1.0' + # Version 1.1: Added 'action_plan' object field + VERSION = '1.1' - dbapi = dbapi.get_instance() + dbapi = db_api.get_instance() fields = { 'id': wfields.IntegerField(), 'uuid': wfields.UUIDField(), - 'action_plan_id': wfields.IntegerField(nullable=True), + 'action_plan_id': wfields.IntegerField(), 'action_type': wfields.StringField(nullable=True), 'input_parameters': wfields.DictField(nullable=True), 'state': wfields.StringField(nullable=True), 'next': wfields.IntegerField(nullable=True), + + 'action_plan': wfields.ObjectField('ActionPlan', nullable=True), + } + object_fields = { + 'action_plan': (objects.ActionPlan, 'action_plan_id'), } @base.remotable_classmethod - def get(cls, context, action_id): + def get(cls, context, action_id, eager=False): """Find a action based on its id or uuid and return a Action object. :param action_id: the id *or* uuid of a action. + :param eager: Load object fields if True (Default: False) :returns: a :class:`Action` object. """ if utils.is_int_like(action_id): - return cls.get_by_id(context, action_id) + return cls.get_by_id(context, action_id, eager=eager) elif utils.is_uuid_like(action_id): - return cls.get_by_uuid(context, action_id) + return cls.get_by_uuid(context, action_id, eager=eager) else: raise exception.InvalidIdentity(identity=action_id) @base.remotable_classmethod - def get_by_id(cls, context, action_id): + def get_by_id(cls, context, action_id, eager=False): """Find a action based on its integer id and return a Action object. :param action_id: the id of a action. + :param eager: Load object fields if True (Default: False) :returns: a :class:`Action` object. """ - db_action = cls.dbapi.get_action_by_id(context, action_id) - action = Action._from_db_object(cls(context), db_action) + db_action = cls.dbapi.get_action_by_id(context, action_id, eager=eager) + action = cls._from_db_object(cls(context), db_action, eager=eager) return action @base.remotable_classmethod - def get_by_uuid(cls, context, uuid): + def get_by_uuid(cls, context, uuid, eager=False): """Find a action based on uuid and return a :class:`Action` object. :param uuid: the uuid of a action. :param context: Security context + :param eager: Load object fields if True (Default: False) :returns: a :class:`Action` object. """ - db_action = cls.dbapi.get_action_by_uuid(context, uuid) - action = Action._from_db_object(cls(context), db_action) + db_action = cls.dbapi.get_action_by_uuid(context, uuid, eager=eager) + action = cls._from_db_object(cls(context), db_action, eager=eager) return action @base.remotable_classmethod def list(cls, context, limit=None, marker=None, filters=None, - sort_key=None, sort_dir=None): + sort_key=None, sort_dir=None, eager=False): """Return a list of Action objects. :param context: Security context. @@ -97,6 +107,7 @@ class Action(base.WatcherPersistentObject, base.WatcherObject, :param filters: Filters to apply. Defaults to None. :param sort_key: column to sort results by. :param sort_dir: direction to sort. "asc" or "desc". + :param eager: Load object fields if True (Default: False) :returns: a list of :class:`Action` object. """ db_actions = cls.dbapi.get_action_list(context, @@ -104,17 +115,23 @@ class Action(base.WatcherPersistentObject, base.WatcherObject, marker=marker, filters=filters, sort_key=sort_key, - sort_dir=sort_dir) + sort_dir=sort_dir, + eager=eager) - return [cls._from_db_object(cls(context), obj) + return [cls._from_db_object(cls(context), obj, eager=eager) for obj in db_actions] @base.remotable def create(self): - """Create a Action record in the DB""" + """Create an :class:`Action` record in the DB. + + :returns: An :class:`Action` object. + """ values = self.obj_get_changes() db_action = self.dbapi.create_action(values) - self._from_db_object(self, db_action) + # Note(v-francoise): Always load eagerly upon creation so we can send + # notifications containing information about the related relationships + self._from_db_object(self, db_action, eager=True) def destroy(self): """Delete the Action from the DB""" @@ -134,14 +151,15 @@ class Action(base.WatcherPersistentObject, base.WatcherObject, self.obj_reset_changes() @base.remotable - def refresh(self): + def refresh(self, eager=False): """Loads updates for this Action. Loads a action with the same uuid from the database and checks for updated attributes. Updates are applied from the loaded action column by column, if there are any updates. + :param eager: Load object fields if True (Default: False) """ - current = self.__class__.get_by_uuid(self._context, uuid=self.uuid) + current = self.get_by_uuid(self._context, uuid=self.uuid, eager=eager) self.obj_refresh(current) @base.remotable diff --git a/watcher/tests/api/v1/test_actions.py b/watcher/tests/api/v1/test_actions.py index 726a77d80..b86259c40 100644 --- a/watcher/tests/api/v1/test_actions.py +++ b/watcher/tests/api/v1/test_actions.py @@ -325,9 +325,9 @@ class TestListAction(api_base.FunctionalTest): action_list.append(action.uuid) response = self.get_json('/actions') response_actions = response['actions'] - for id in [0, 1, 2, 3]: - self.assertEqual(response_actions[id]['next_uuid'], - response_actions[id + 1]['uuid']) + for id_ in range(4): + self.assertEqual(response_actions[id_]['next_uuid'], + response_actions[id_ + 1]['uuid']) def test_many_without_soft_deleted(self): action_list = [] @@ -450,8 +450,7 @@ class TestPatch(api_base.FunctionalTest): response = self.patch_json( '/actions/%s' % self.action.uuid, - [{'path': '/state', 'value': new_state, - 'op': 'replace'}], + [{'path': '/state', 'value': new_state, 'op': 'replace'}], expect_errors=True) self.assertEqual('application/json', response.content_type) self.assertEqual(403, response.status_int) @@ -465,6 +464,7 @@ class TestDelete(api_base.FunctionalTest): self.goal = obj_utils.create_test_goal(self.context) self.strategy = obj_utils.create_test_strategy(self.context) self.audit = obj_utils.create_test_audit(self.context) + self.action_plan = obj_utils.create_test_action_plan(self.context) self.action = obj_utils.create_test_action(self.context, next=None) p = mock.patch.object(db_api.BaseConnection, 'update_action') self.mock_action_update = p.start() @@ -488,6 +488,13 @@ class TestDelete(api_base.FunctionalTest): class TestActionPolicyEnforcement(api_base.FunctionalTest): + def setUp(self): + super(TestActionPolicyEnforcement, self).setUp() + obj_utils.create_test_goal(self.context) + obj_utils.create_test_strategy(self.context) + obj_utils.create_test_audit(self.context) + obj_utils.create_test_action_plan(self.context) + def _common_policy_check(self, rule, func, *arg, **kwarg): self.policy.set_rules({ "admin_api": "(role:admin or role:administrator)", diff --git a/watcher/tests/applier/action_plan/test_default_action_handler.py b/watcher/tests/applier/action_plan/test_default_action_handler.py index 83668400c..d6219700a 100644 --- a/watcher/tests/applier/action_plan/test_default_action_handler.py +++ b/watcher/tests/applier/action_plan/test_default_action_handler.py @@ -26,6 +26,7 @@ from watcher.tests.objects import utils as obj_utils class TestDefaultActionPlanHandler(base.DbTestCase): + def setUp(self): super(TestDefaultActionPlanHandler, self).setUp() obj_utils.create_test_goal(self.context) diff --git a/watcher/tests/db/test_purge.py b/watcher/tests/db/test_purge.py index 486f71811..d69968257 100644 --- a/watcher/tests/db/test_purge.py +++ b/watcher/tests/db/test_purge.py @@ -265,13 +265,13 @@ class TestPurgeCommand(base.DbTestCase): with freezegun.freeze_time(self.fake_today): # orphan audit template audit_template4 = obj_utils.create_test_audit_template( - self.context, goal_id=self.goal2.id, # Does not exist + self.context, goal_id=self.goal2.id, name=self.generate_unique_name(prefix="Audit Template 4 "), - strategy_id=None, id=self._generate_id(), + strategy_id=self.strategy1.id, id=self._generate_id(), uuid=utils.generate_uuid()) audit4 = obj_utils.create_test_audit( self.context, audit_template_id=audit_template4.id, - id=self._generate_id(), + strategy_id=self.strategy1.id, id=self._generate_id(), uuid=utils.generate_uuid()) action_plan4 = obj_utils.create_test_action_plan( self.context, @@ -289,7 +289,7 @@ class TestPurgeCommand(base.DbTestCase): uuid=utils.generate_uuid()) audit5 = obj_utils.create_test_audit( self.context, audit_template_id=audit_template5.id, - id=self._generate_id(), + strategy_id=self.strategy1.id, id=self._generate_id(), uuid=utils.generate_uuid()) action_plan5 = obj_utils.create_test_action_plan( self.context, @@ -362,7 +362,7 @@ class TestPurgeCommand(base.DbTestCase): with freezegun.freeze_time(self.fake_today): # orphan audit template audit_template4 = obj_utils.create_test_audit_template( - self.context, goal_id=self.goal2.id, # Does not exist + self.context, goal_id=self.goal2.id, name=self.generate_unique_name(prefix="Audit Template 4 "), strategy_id=None, id=self._generate_id(), uuid=utils.generate_uuid()) @@ -386,7 +386,7 @@ class TestPurgeCommand(base.DbTestCase): uuid=utils.generate_uuid()) audit5 = obj_utils.create_test_audit( self.context, audit_template_id=audit_template5.id, - id=self._generate_id(), + strategy_id=self.strategy1.id, id=self._generate_id(), uuid=utils.generate_uuid()) action_plan5 = obj_utils.create_test_action_plan( self.context, diff --git a/watcher/tests/db/utils.py b/watcher/tests/db/utils.py index 9a05a7a7b..07a9439b1 100644 --- a/watcher/tests/db/utils.py +++ b/watcher/tests/db/utils.py @@ -102,7 +102,7 @@ def create_test_audit(**kwargs): def get_test_action(**kwargs): - return { + action_data = { 'id': kwargs.get('id', 1), 'uuid': kwargs.get('uuid', '10a47dd1-4874-4298-91cf-eff046dbdb8d'), 'action_plan_id': kwargs.get('action_plan_id', 1), @@ -120,6 +120,13 @@ def get_test_action(**kwargs): 'deleted_at': kwargs.get('deleted_at'), } + # ObjectField doesn't allow None nor dict, so if we want to simulate a + # non-eager object loading, the field should not be referenced at all. + if kwargs.get('action_plan'): + action_data['action_plan'] = kwargs.get('action_plan') + + return action_data + def create_test_action(**kwargs): """Create test action entry in DB and return Action DB object. diff --git a/watcher/tests/objects/test_action.py b/watcher/tests/objects/test_action.py index 4085904cf..1b3157e67 100644 --- a/watcher/tests/objects/test_action.py +++ b/watcher/tests/objects/test_action.py @@ -14,103 +14,164 @@ # under the License. import mock + from watcher.common import exception +from watcher.db.sqlalchemy import api as db_api from watcher import objects -from watcher.objects import action as actionobject from watcher.tests.db import base from watcher.tests.db import utils class TestActionObject(base.DbTestCase): + action_plan_id = 2 + + scenarios = [ + ('non_eager', dict( + eager=False, + fake_action=utils.get_test_action( + action_plan_id=action_plan_id))), + ('eager_with_non_eager_load', dict( + eager=True, + fake_action=utils.get_test_action( + action_plan_id=action_plan_id))), + ('eager_with_eager_load', dict( + eager=True, + fake_action=utils.get_test_action( + action_plan_id=action_plan_id, + action_plan=utils.get_test_action_plan(id=action_plan_id)))), + ] + def setUp(self): super(TestActionObject, self).setUp() - self.fake_action = utils.get_test_action() + self.fake_action_plan = utils.create_test_action_plan( + id=self.action_plan_id) - def test_get_by_id(self): + def eager_action_assert(self, action): + if self.eager: + self.assertIsNotNone(action.action_plan) + fields_to_check = set( + super(objects.ActionPlan, objects.ActionPlan).fields + ).symmetric_difference(objects.ActionPlan.fields) + db_data = { + k: v for k, v in self.fake_action_plan.as_dict().items() + if k in fields_to_check} + object_data = { + k: v for k, v in action.action_plan.as_dict().items() + if k in fields_to_check} + self.assertEqual(db_data, object_data) + + @mock.patch.object(db_api.Connection, 'get_action_by_id') + def test_get_by_id(self, mock_get_action): + mock_get_action.return_value = self.fake_action action_id = self.fake_action['id'] - with mock.patch.object(self.dbapi, 'get_action_by_id', - autospec=True) as mock_get_action: - mock_get_action.return_value = self.fake_action - action = objects.Action.get(self.context, action_id) - mock_get_action.assert_called_once_with(self.context, - action_id) - self.assertEqual(self.context, action._context) + action = objects.Action.get(self.context, action_id, eager=self.eager) + mock_get_action.assert_called_once_with( + self.context, action_id, eager=self.eager) + self.assertEqual(self.context, action._context) + self.eager_action_assert(action) - def test_get_by_uuid(self): + @mock.patch.object(db_api.Connection, 'get_action_by_uuid') + def test_get_by_uuid(self, mock_get_action): + mock_get_action.return_value = self.fake_action uuid = self.fake_action['uuid'] - with mock.patch.object(self.dbapi, 'get_action_by_uuid', - autospec=True) as mock_get_action: - mock_get_action.return_value = self.fake_action - action = objects.Action.get(self.context, uuid) - mock_get_action.assert_called_once_with(self.context, uuid) - self.assertEqual(self.context, action._context) + action = objects.Action.get(self.context, uuid, eager=self.eager) + mock_get_action.assert_called_once_with( + self.context, uuid, eager=self.eager) + self.assertEqual(self.context, action._context) def test_get_bad_id_and_uuid(self): self.assertRaises(exception.InvalidIdentity, - objects.Action.get, self.context, 'not-a-uuid') + objects.Action.get, self.context, 'not-a-uuid', + eager=self.eager) - def test_list(self): - with mock.patch.object(self.dbapi, 'get_action_list', - autospec=True) as mock_get_list: - mock_get_list.return_value = [self.fake_action] - actions = objects.Action.list(self.context) - self.assertEqual(1, mock_get_list.call_count) - self.assertEqual(1, len(actions)) - self.assertIsInstance(actions[0], objects.Action) - self.assertEqual(self.context, actions[0]._context) + @mock.patch.object(db_api.Connection, 'get_action_list') + def test_list(self, mock_get_list): + mock_get_list.return_value = [self.fake_action] + actions = objects.Action.list(self.context, eager=self.eager) + self.assertEqual(1, mock_get_list.call_count) + self.assertEqual(1, len(actions)) + self.assertIsInstance(actions[0], objects.Action) + self.assertEqual(self.context, actions[0]._context) + for action in actions: + self.eager_action_assert(action) - def test_create(self): - with mock.patch.object(self.dbapi, 'create_action', - autospec=True) as mock_create_action: - mock_create_action.return_value = self.fake_action - action = objects.Action(self.context, **self.fake_action) - - action.create() - mock_create_action.assert_called_once_with(self.fake_action) - self.assertEqual(self.context, action._context) - - def test_destroy(self): + @mock.patch.object(db_api.Connection, 'update_action') + @mock.patch.object(db_api.Connection, 'get_action_by_uuid') + def test_save(self, mock_get_action, mock_update_action): + mock_get_action.return_value = self.fake_action uuid = self.fake_action['uuid'] - with mock.patch.object(self.dbapi, 'get_action_by_uuid', - autospec=True) as mock_get_action: - mock_get_action.return_value = self.fake_action - with mock.patch.object(self.dbapi, 'destroy_action', - autospec=True) as mock_destroy_action: - action = objects.Action.get_by_uuid(self.context, uuid) - action.destroy() - mock_get_action.assert_called_once_with(self.context, uuid) - mock_destroy_action.assert_called_once_with(uuid) - self.assertEqual(self.context, action._context) + action = objects.Action.get_by_uuid( + self.context, uuid, eager=self.eager) + action.state = objects.action.State.SUCCEEDED + action.save() - def test_save(self): - uuid = self.fake_action['uuid'] - with mock.patch.object(self.dbapi, 'get_action_by_uuid', - autospec=True) as mock_get_action: - mock_get_action.return_value = self.fake_action - with mock.patch.object(self.dbapi, 'update_action', - autospec=True) as mock_update_action: - action = objects.Action.get_by_uuid(self.context, uuid) - action.state = actionobject.State.SUCCEEDED - action.save() + mock_get_action.assert_called_once_with( + self.context, uuid, eager=self.eager) + mock_update_action.assert_called_once_with( + uuid, {'state': objects.action.State.SUCCEEDED}) + self.assertEqual(self.context, action._context) - mock_get_action.assert_called_once_with(self.context, uuid) - mock_update_action.assert_called_once_with( - uuid, {'state': actionobject.State.SUCCEEDED}) - self.assertEqual(self.context, action._context) - - def test_refresh(self): - uuid = self.fake_action['uuid'] + @mock.patch.object(db_api.Connection, 'get_action_by_uuid') + def test_refresh(self, mock_get_action): returns = [dict(self.fake_action, state="first state"), dict(self.fake_action, state="second state")] - expected = [mock.call(self.context, uuid), - mock.call(self.context, uuid)] - with mock.patch.object(self.dbapi, 'get_action_by_uuid', - side_effect=returns, - autospec=True) as mock_get_action: - action = objects.Action.get(self.context, uuid) - self.assertEqual("first state", action.state) - action.refresh() - self.assertEqual("second state", action.state) - self.assertEqual(expected, mock_get_action.call_args_list) - self.assertEqual(self.context, action._context) + mock_get_action.side_effect = returns + uuid = self.fake_action['uuid'] + expected = [mock.call(self.context, uuid, eager=self.eager), + mock.call(self.context, uuid, eager=self.eager)] + action = objects.Action.get(self.context, uuid, eager=self.eager) + self.assertEqual("first state", action.state) + action.refresh(eager=self.eager) + self.assertEqual("second state", action.state) + self.assertEqual(expected, mock_get_action.call_args_list) + self.assertEqual(self.context, action._context) + self.eager_action_assert(action) + + +class TestCreateDeleteActionObject(base.DbTestCase): + + def setUp(self): + super(TestCreateDeleteActionObject, self).setUp() + self.fake_strategy = utils.create_test_strategy(name="DUMMY") + self.fake_audit = utils.create_test_audit() + self.fake_action_plan = utils.create_test_action_plan() + self.fake_action = utils.get_test_action() + + @mock.patch.object(db_api.Connection, 'create_action') + def test_create(self, mock_create_action): + mock_create_action.return_value = self.fake_action + action = objects.Action(self.context, **self.fake_action) + action.create() + + mock_create_action.assert_called_once_with(self.fake_action) + self.assertEqual(self.context, action._context) + + @mock.patch.object(db_api.Connection, 'update_action') + @mock.patch.object(db_api.Connection, 'soft_delete_action') + @mock.patch.object(db_api.Connection, 'get_action_by_uuid') + def test_soft_delete(self, mock_get_action, + mock_soft_delete_action, mock_update_action): + mock_get_action.return_value = self.fake_action + uuid = self.fake_action['uuid'] + action = objects.Action.get_by_uuid(self.context, uuid) + action.soft_delete() + mock_get_action.assert_called_once_with( + self.context, uuid, eager=False) + mock_soft_delete_action.assert_called_once_with(uuid) + mock_update_action.assert_called_once_with( + uuid, {'state': objects.action.State.DELETED}) + self.assertEqual(self.context, action._context) + + @mock.patch.object(db_api.Connection, 'destroy_action') + @mock.patch.object(db_api.Connection, 'get_action_by_uuid') + def test_destroy(self, mock_get_action, mock_destroy_action): + mock_get_action.return_value = self.fake_action + uuid = self.fake_action['uuid'] + action = objects.Action.get_by_uuid(self.context, uuid) + action.destroy() + + mock_get_action.assert_called_once_with( + self.context, uuid, eager=False) + mock_destroy_action.assert_called_once_with(uuid) + self.assertEqual(self.context, action._context) diff --git a/watcher/tests/objects/test_objects.py b/watcher/tests/objects/test_objects.py index 9a27d302d..5d1fbd1d0 100644 --- a/watcher/tests/objects/test_objects.py +++ b/watcher/tests/objects/test_objects.py @@ -414,7 +414,7 @@ expected_object_fingerprints = { 'AuditTemplate': '1.1-b291973ffc5efa2c61b24fe34fdccc0b', 'Audit': '1.1-dc246337c8d511646cb537144fcb0f3a', 'ActionPlan': '1.1-299bd9c76f2402a0b2167f8e4d744a05', - 'Action': '1.0-a78f69c0da98e13e601f9646f6b2f883', + 'Action': '1.1-52c77e4db4ce0aa9480c9760faec61a1', 'EfficacyIndicator': '1.0-655b71234a82bc7478aff964639c4bb0', 'ScoringEngine': '1.0-4abbe833544000728e17bd9e83f97576', 'Service': '1.0-4b35b99ada9677a882c9de2b30212f35',