From f54aca70cc287826897115f8f23565cf9ac3b6cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vincent=20Fran=C3=A7oise?= Date: Fri, 23 Sep 2016 16:46:41 +0200 Subject: [PATCH] Added 'goal' ObjectField for Strategy object In this changeset, I added the "goal" ObjectField which can either be loaded by setting the new "eager" parameter as True or not loaded (as before) by setting it to False. The advantage of introducing this eager parameter is that this way, we can reduce to a minimum the overhead of DB queries whenever the related goal is not actually needed. Partially-Implements: blueprint watcher-versioned-objects Change-Id: I103c9ed161d2cedf7b43c55f9e095ef66bf44dea --- watcher/objects/base.py | 41 +++- watcher/objects/fields.py | 9 +- watcher/objects/strategy.py | 69 ++++-- watcher/tests/api/v1/test_strategies.py | 24 +- watcher/tests/db/utils.py | 11 +- .../planner/test_default_planner.py | 17 +- watcher/tests/objects/test_objects.py | 2 +- watcher/tests/objects/test_strategy.py | 206 +++++++++++------- watcher/tests/objects/utils.py | 8 + 9 files changed, 255 insertions(+), 132 deletions(-) diff --git a/watcher/objects/base.py b/watcher/objects/base.py index d23e3fd48..97d3bed08 100644 --- a/watcher/objects/base.py +++ b/watcher/objects/base.py @@ -103,6 +103,19 @@ class WatcherPersistentObject(object): 'deleted_at': ovo_fields.DateTimeField(nullable=True), } + # Mapping between the object field name and a 2-tuple pair composed of + # its object type (e.g. objects.RelatedObject) and the the name of the + # model field related ID (or UUID) foreign key field. + # e.g.: + # + # fields = { + # # [...] + # 'related_object_id': fields.IntegerField(), # Foreign key + # 'related_object': wfields.ObjectField('RelatedObject'), + # } + # {'related_object': (objects.RelatedObject, 'related_object_id')} + object_fields = {} + def obj_refresh(self, loaded_object): """Applies updates for objects that inherit from base.WatcherObject. @@ -116,17 +129,39 @@ class WatcherPersistentObject(object): self[field] = loaded_object[field] @staticmethod - def _from_db_object(obj, db_object): + def _from_db_object(obj, db_object, eager=False): """Converts a database entity to a formal object. :param obj: An object of the class. :param db_object: A DB model of the object + :param eager: Enable the loading of object fields (Default: False) :return: The object of the class with the database entity added - """ + obj_class = type(obj) + object_fields = obj_class.object_fields for field in obj.fields: - obj[field] = db_object[field] + if field not in object_fields: + obj[field] = db_object[field] + + if eager: + # Load object fields + context = obj._context + loadable_fields = ( + (obj_field, related_obj_cls, rel_id) + for obj_field, (related_obj_cls, rel_id) + in object_fields.items() + if obj[rel_id] + ) + for obj_field, related_obj_cls, rel_id in loadable_fields: + if db_object.get(obj_field) and obj[rel_id]: + # The object field data was eagerly loaded alongside + # the main object data + obj[obj_field] = related_obj_cls._from_db_object( + related_obj_cls(context), db_object[obj_field]) + else: + # The object field data wasn't loaded yet + obj[obj_field] = related_obj_cls.get(context, obj[rel_id]) obj.obj_reset_changes() return obj diff --git a/watcher/objects/fields.py b/watcher/objects/fields.py index c1fe2061f..7b33bfae7 100644 --- a/watcher/objects/fields.py +++ b/watcher/objects/fields.py @@ -24,12 +24,13 @@ from oslo_versionedobjects import fields LOG = log.getLogger(__name__) -IntegerField = fields.IntegerField -UUIDField = fields.UUIDField -StringField = fields.StringField -DateTimeField = fields.DateTimeField BooleanField = fields.BooleanField +DateTimeField = fields.DateTimeField +IntegerField = fields.IntegerField ListOfStringsField = fields.ListOfStringsField +ObjectField = fields.ObjectField +StringField = fields.StringField +UUIDField = fields.UUIDField class Numeric(fields.FieldType): diff --git a/watcher/objects/strategy.py b/watcher/objects/strategy.py index 6f95a103e..584c8ff23 100644 --- a/watcher/objects/strategy.py +++ b/watcher/objects/strategy.py @@ -17,6 +17,7 @@ from watcher.common import exception from watcher.common import utils from watcher.db import api as db_api +from watcher import objects from watcher.objects import base from watcher.objects import fields as wfields @@ -25,6 +26,10 @@ from watcher.objects import fields as wfields class Strategy(base.WatcherPersistentObject, base.WatcherObject, base.WatcherObjectDictCompat): + # Version 1.0: Initial version + # Version 1.1: Added Goal object field + VERSION = '1.1' + dbapi = db_api.get_instance() fields = { @@ -34,10 +39,13 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, 'display_name': wfields.StringField(), 'goal_id': wfields.IntegerField(), 'parameters_spec': wfields.FlexibleDictField(nullable=True), + 'goal': wfields.ObjectField('Goal', nullable=True), } + object_fields = {'goal': (objects.Goal, 'goal_id')} + @base.remotable_classmethod - def get(cls, context, strategy_id): + def get(cls, context, strategy_id, eager=False): """Find a strategy based on its id or uuid :param context: Security context. NOTE: This should only @@ -47,17 +55,18 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, A context should be set when instantiating the object, e.g.: Strategy(context) :param strategy_id: the id *or* uuid of a strategy. - :returns: a :class:`Strategy` object. + :param eager: Load object fields if True (Default: False) + :returns: A :class:`Strategy` object. """ if utils.is_int_like(strategy_id): - return cls.get_by_id(context, strategy_id) + return cls.get_by_id(context, strategy_id, eager=eager) elif utils.is_uuid_like(strategy_id): - return cls.get_by_uuid(context, strategy_id) + return cls.get_by_uuid(context, strategy_id, eager=eager) else: raise exception.InvalidIdentity(identity=strategy_id) @base.remotable_classmethod - def get_by_id(cls, context, strategy_id): + def get_by_id(cls, context, strategy_id, eager=False): """Find a strategy based on its integer id :param context: Security context. NOTE: This should only @@ -67,14 +76,16 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, A context should be set when instantiating the object, e.g.: Strategy(context) :param strategy_id: the id of a strategy. - :returns: a :class:`Strategy` object. + :param eager: Load object fields if True (Default: False) + :returns: A :class:`Strategy` object. """ - db_strategy = cls.dbapi.get_strategy_by_id(context, strategy_id) - strategy = Strategy._from_db_object(cls(context), db_strategy) + db_strategy = cls.dbapi.get_strategy_by_id( + context, strategy_id, eager=eager) + strategy = cls._from_db_object(cls(context), db_strategy, eager=eager) return strategy @base.remotable_classmethod - def get_by_uuid(cls, context, uuid): + def get_by_uuid(cls, context, uuid, eager=False): """Find a strategy based on uuid :param context: Security context. NOTE: This should only @@ -84,29 +95,33 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, A context should be set when instantiating the object, e.g.: Strategy(context) :param uuid: the uuid of a strategy. - :returns: a :class:`Strategy` object. + :param eager: Load object fields if True (Default: False) + :returns: A :class:`Strategy` object. """ - db_strategy = cls.dbapi.get_strategy_by_uuid(context, uuid) - strategy = cls._from_db_object(cls(context), db_strategy) + db_strategy = cls.dbapi.get_strategy_by_uuid( + context, uuid, eager=eager) + strategy = cls._from_db_object(cls(context), db_strategy, eager=eager) return strategy @base.remotable_classmethod - def get_by_name(cls, context, name): + def get_by_name(cls, context, name, eager=False): """Find a strategy based on name - :param name: the name of a strategy. :param context: Security context - :returns: a :class:`Strategy` object. + :param name: the name of a strategy. + :param eager: Load object fields if True (Default: False) + :returns: A :class:`Strategy` object. """ - db_strategy = cls.dbapi.get_strategy_by_name(context, name) - strategy = cls._from_db_object(cls(context), db_strategy) + db_strategy = cls.dbapi.get_strategy_by_name( + context, name, eager=eager) + strategy = cls._from_db_object(cls(context), db_strategy, eager=eager) return strategy @base.remotable_classmethod def list(cls, context, limit=None, marker=None, filters=None, - sort_key=None, sort_dir=None): + sort_key=None, sort_dir=None, eager=False): """Return a list of :class:`Strategy` objects. :param context: Security context. NOTE: This should only @@ -115,11 +130,12 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, argument, even though we don't use it. A context should be set when instantiating the object, e.g.: Strategy(context) - :param filters: dict mapping the filter key to a value. :param limit: maximum number of resources to return in a single result. :param marker: pagination marker for large data sets. + :param filters: dict mapping the filter key to a value. :param sort_key: column to sort results by. - :param sort_dir: direction to sort. "asc" or "desc". + :param sort_dir: direction to sort. "asc" or "desc`". + :param eager: Load object fields if True (Default: False) :returns: a list of :class:`Strategy` object. """ db_strategies = cls.dbapi.get_strategy_list( @@ -130,7 +146,7 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, sort_key=sort_key, sort_dir=sort_dir) - return [cls._from_db_object(cls(context), obj) + return [cls._from_db_object(cls(context), obj, eager=eager) for obj in db_strategies] @base.remotable @@ -143,11 +159,14 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, argument, even though we don't use it. A context should be set when instantiating the object, e.g.: Strategy(context) + :returns: A :class:`Strategy` object. """ values = self.obj_get_changes() db_strategy = self.dbapi.create_strategy(values) - self._from_db_object(self, db_strategy) + # Note(v-francoise): Always load eagerly upon creation so we can send + # notifications containing information about the related relationships + self._from_db_object(self, db_strategy, eager=True) def destroy(self, context=None): """Delete the :class:`Strategy` from the DB. @@ -182,7 +201,7 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, self.obj_reset_changes() @base.remotable - def refresh(self, context=None): + def refresh(self, context=None, eager=False): """Loads updates for this :class:`Strategy`. Loads a strategy with the same uuid from the database and @@ -195,8 +214,10 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, argument, even though we don't use it. A context should be set when instantiating the object, e.g.: Strategy(context) + :param eager: Load object fields if True (Default: False) """ - current = self.__class__.get_by_id(self._context, strategy_id=self.id) + current = self.__class__.get_by_id( + self._context, strategy_id=self.id, eager=eager) for field in self.fields: if (hasattr(self, base.get_attrname(field)) and self[field] != current[field]): diff --git a/watcher/tests/api/v1/test_strategies.py b/watcher/tests/api/v1/test_strategies.py index 34121a800..5be3fd448 100644 --- a/watcher/tests/api/v1/test_strategies.py +++ b/watcher/tests/api/v1/test_strategies.py @@ -21,6 +21,11 @@ from watcher.tests.objects import utils as obj_utils class TestListStrategy(api_base.FunctionalTest): + def setUp(self): + super(TestListStrategy, self).setUp() + self.fake_goal = obj_utils.create_test_goal( + self.context, uuid=utils.generate_uuid()) + def _assert_strategy_fields(self, strategy): strategy_fields = ['uuid', 'name', 'display_name', 'goal_uuid'] for field in strategy_fields: @@ -61,7 +66,6 @@ class TestListStrategy(api_base.FunctionalTest): self.assertEqual(404, response.status_int) def test_detail(self): - obj_utils.create_test_goal(self.context) strategy = obj_utils.create_test_strategy(self.context) response = self.get_json('/strategies/detail') self.assertEqual(strategy.uuid, response['strategies'][0]["uuid"]) @@ -78,7 +82,6 @@ class TestListStrategy(api_base.FunctionalTest): self.assertEqual(404, response.status_int) def test_many(self): - obj_utils.create_test_goal(self.context) strategy_list = [] for idx in range(1, 6): strategy = obj_utils.create_test_strategy( @@ -132,12 +135,12 @@ class TestListStrategy(api_base.FunctionalTest): def test_filter_by_goal_uuid(self): goal1 = obj_utils.create_test_goal( self.context, - id=1, + id=2, uuid=utils.generate_uuid(), name='My_Goal 1') goal2 = obj_utils.create_test_goal( self.context, - id=2, + id=3, uuid=utils.generate_uuid(), name='My Goal 2') @@ -164,12 +167,12 @@ class TestListStrategy(api_base.FunctionalTest): def test_filter_by_goal_name(self): goal1 = obj_utils.create_test_goal( self.context, - id=1, + id=2, uuid=utils.generate_uuid(), name='My_Goal 1') goal2 = obj_utils.create_test_goal( self.context, - id=2, + id=3, uuid=utils.generate_uuid(), name='My Goal 2') @@ -196,6 +199,11 @@ class TestListStrategy(api_base.FunctionalTest): class TestStrategyPolicyEnforcement(api_base.FunctionalTest): + def setUp(self): + super(TestStrategyPolicyEnforcement, self).setUp() + self.fake_goal = obj_utils.create_test_goal( + self.context, uuid=utils.generate_uuid()) + def _common_policy_check(self, rule, func, *arg, **kwarg): self.policy.set_rules({ "admin_api": "(role:admin or role:administrator)", @@ -227,8 +235,8 @@ class TestStrategyPolicyEnforcement(api_base.FunctionalTest): expect_errors=True) -class TestStrategyEnforcementWithAdminContext(TestListStrategy, - api_base.AdminRoleTest): +class TestStrategyEnforcementWithAdminContext( + TestListStrategy, api_base.AdminRoleTest): def setUp(self): super(TestStrategyEnforcementWithAdminContext, self).setUp() diff --git a/watcher/tests/db/utils.py b/watcher/tests/db/utils.py index ca5ef2dd9..ad91aa9cd 100644 --- a/watcher/tests/db/utils.py +++ b/watcher/tests/db/utils.py @@ -199,7 +199,7 @@ def create_test_scoring_engine(**kwargs): def get_test_strategy(**kwargs): - return { + strategy_data = { 'id': kwargs.get('id', 1), 'uuid': kwargs.get('uuid', 'cb3d0b58-4415-4d90-b75b-1e96878730e3'), 'name': kwargs.get('name', 'TEST'), @@ -208,9 +208,16 @@ def get_test_strategy(**kwargs): 'created_at': kwargs.get('created_at'), 'updated_at': kwargs.get('updated_at'), 'deleted_at': kwargs.get('deleted_at'), - 'parameters_spec': kwargs.get('parameters_spec', {}) + 'parameters_spec': kwargs.get('parameters_spec', {}), } + # goal ObjectField doesn't allow None nor dict, so if we want to simulate a + # non-eager object loading, the field should not be referenced at all. + if kwargs.get('goal'): + strategy_data['goal'] = kwargs.get('goal') + + return strategy_data + def get_test_service(**kwargs): return { diff --git a/watcher/tests/decision_engine/planner/test_default_planner.py b/watcher/tests/decision_engine/planner/test_default_planner.py index 30c7da5b9..34656f1ed 100644 --- a/watcher/tests/decision_engine/planner/test_default_planner.py +++ b/watcher/tests/decision_engine/planner/test_default_planner.py @@ -159,8 +159,11 @@ class TestDefaultPlanner(base.DbTestCase): 'migrate': 3 } - obj_utils.create_test_audit_template(self.context) - self.strategy = obj_utils.create_test_strategy(self.context) + self.goal = obj_utils.create_test_goal(self.context) + self.strategy = obj_utils.create_test_strategy( + self.context, goal_id=self.goal.id) + obj_utils.create_test_audit_template( + self.context, goal_id=self.goal.id, strategy_id=self.strategy.id) p = mock.patch.object(db_api.BaseConnection, 'create_action_plan') self.mock_create_action_plan = p.start() @@ -185,7 +188,8 @@ class TestDefaultPlanner(base.DbTestCase): @mock.patch.object(objects.Strategy, 'get_by_name') def test_schedule_scheduled_empty(self, m_get_by_name): m_get_by_name.return_value = self.strategy - audit = db_utils.create_test_audit(uuid=utils.generate_uuid()) + audit = db_utils.create_test_audit( + goal_id=self.goal.id, strategy_id=self.strategy.id) fake_solution = SolutionFakerSingleHyp.build() action_plan = self.default_planner.schedule(self.context, audit.id, fake_solution) @@ -194,8 +198,9 @@ class TestDefaultPlanner(base.DbTestCase): @mock.patch.object(objects.Strategy, 'get_by_name') def test_scheduler_warning_empty_action_plan(self, m_get_by_name): m_get_by_name.return_value = self.strategy - audit = db_utils.create_test_audit(uuid=utils.generate_uuid()) + audit = db_utils.create_test_audit( + goal_id=self.goal.id, strategy_id=self.strategy.id) fake_solution = SolutionFaker.build() - action_plan = self.default_planner.schedule(self.context, - audit.id, fake_solution) + action_plan = self.default_planner.schedule( + self.context, audit.id, fake_solution) self.assertIsNotNone(action_plan.uuid) diff --git a/watcher/tests/objects/test_objects.py b/watcher/tests/objects/test_objects.py index 165b71bdc..d71445875 100644 --- a/watcher/tests/objects/test_objects.py +++ b/watcher/tests/objects/test_objects.py @@ -410,7 +410,7 @@ class TestObject(_LocalTest, _TestObject): # The fingerprint values should only be changed if there is a version bump. expected_object_fingerprints = { 'Goal': '1.0-93881622db05e7b67a65ca885b4a022e', - 'Strategy': '1.0-e60f62cc854c6e63fb1c3befbfc8629e', + 'Strategy': '1.1-73f164491bdd4c034f48083a51bdeb7b', 'AuditTemplate': '1.0-7432ee4d3ce0c7cbb9d11a4565ee8eb6', 'Audit': '1.0-ebfc5360d019baf583a10a8a27071c97', 'ActionPlan': '1.0-cc76fd7f0e8479aeff817dd266341de4', diff --git a/watcher/tests/objects/test_strategy.py b/watcher/tests/objects/test_strategy.py index dc315ccbe..3d1a47dc5 100644 --- a/watcher/tests/objects/test_strategy.py +++ b/watcher/tests/objects/test_strategy.py @@ -15,7 +15,9 @@ import mock + from watcher.common import exception +from watcher.db.sqlalchemy import api as db_api from watcher import objects from watcher.tests.db import base from watcher.tests.db import utils @@ -23,107 +25,143 @@ from watcher.tests.db import utils class TestStrategyObject(base.DbTestCase): + goal_id = 2 + + scenarios = [ + ('non_eager', dict( + eager=False, fake_strategy=utils.get_test_strategy( + goal_id=goal_id))), + ('eager_with_non_eager_load', dict( + eager=True, fake_strategy=utils.get_test_strategy( + goal_id=goal_id))), + ('eager_with_eager_load', dict( + eager=True, fake_strategy=utils.get_test_strategy( + goal_id=goal_id, goal=utils.get_test_goal(id=goal_id)))), + ] + def setUp(self): super(TestStrategyObject, self).setUp() - self.fake_strategy = utils.get_test_strategy() + self.fake_goal = utils.create_test_goal(id=self.goal_id) - def test_get_by_id(self): + def eager_load_strategy_assert(self, strategy): + if self.eager: + self.assertIsNotNone(strategy.goal) + fields_to_check = set( + super(objects.Goal, objects.Goal).fields + ).symmetric_difference(objects.Goal.fields) + db_data = { + k: v for k, v in self.fake_goal.as_dict().items() + if k in fields_to_check} + object_data = { + k: v for k, v in strategy.goal.as_dict().items() + if k in fields_to_check} + self.assertEqual(db_data, object_data) + + @mock.patch.object(db_api.Connection, 'get_strategy_by_id') + def test_get_by_id(self, mock_get_strategy): strategy_id = self.fake_strategy['id'] - with mock.patch.object(self.dbapi, 'get_strategy_by_id', - autospec=True) as mock_get_strategy: - mock_get_strategy.return_value = self.fake_strategy - strategy = objects.Strategy.get(self.context, strategy_id) - mock_get_strategy.assert_called_once_with(self.context, - strategy_id) - self.assertEqual(self.context, strategy._context) + mock_get_strategy.return_value = self.fake_strategy + strategy = objects.Strategy.get( + self.context, strategy_id, eager=self.eager) + mock_get_strategy.assert_called_once_with( + self.context, strategy_id, eager=self.eager) + self.assertEqual(self.context, strategy._context) + self.eager_load_strategy_assert(strategy) - def test_get_by_uuid(self): + @mock.patch.object(db_api.Connection, 'get_strategy_by_uuid') + def test_get_by_uuid(self, mock_get_strategy): uuid = self.fake_strategy['uuid'] - with mock.patch.object(self.dbapi, 'get_strategy_by_uuid', - autospec=True) as mock_get_strategy: - mock_get_strategy.return_value = self.fake_strategy - strategy = objects.Strategy.get(self.context, uuid) - mock_get_strategy.assert_called_once_with(self.context, uuid) - self.assertEqual(self.context, strategy._context) + mock_get_strategy.return_value = self.fake_strategy + strategy = objects.Strategy.get(self.context, uuid, eager=self.eager) + mock_get_strategy.assert_called_once_with( + self.context, uuid, eager=self.eager) + self.assertEqual(self.context, strategy._context) + self.eager_load_strategy_assert(strategy) def test_get_bad_uuid(self): self.assertRaises(exception.InvalidIdentity, objects.Strategy.get, self.context, 'not-a-uuid') - def test_list(self): - with mock.patch.object(self.dbapi, 'get_strategy_list', - autospec=True) as mock_get_list: - mock_get_list.return_value = [self.fake_strategy] - strategies = objects.Strategy.list(self.context) - self.assertEqual(1, mock_get_list.call_count, 1) - self.assertEqual(1, len(strategies)) - self.assertIsInstance(strategies[0], objects.Strategy) - self.assertEqual(self.context, strategies[0]._context) + @mock.patch.object(db_api.Connection, 'get_strategy_list') + def test_list(self, mock_get_list): + mock_get_list.return_value = [self.fake_strategy] + strategies = objects.Strategy.list(self.context, eager=self.eager) + self.assertEqual(1, mock_get_list.call_count, 1) + self.assertEqual(1, len(strategies)) + self.assertIsInstance(strategies[0], objects.Strategy) + self.assertEqual(self.context, strategies[0]._context) + for strategy in strategies: + self.eager_load_strategy_assert(strategy) - def test_create(self): - with mock.patch.object(self.dbapi, 'create_strategy', - autospec=True) as mock_create_strategy: - mock_create_strategy.return_value = self.fake_strategy - strategy = objects.Strategy(self.context, **self.fake_strategy) - - strategy.create() - mock_create_strategy.assert_called_once_with(self.fake_strategy) - self.assertEqual(self.context, strategy._context) - - def test_destroy(self): + @mock.patch.object(db_api.Connection, 'update_strategy') + @mock.patch.object(db_api.Connection, 'get_strategy_by_id') + def test_save(self, mock_get_strategy, mock_update_strategy): _id = self.fake_strategy['id'] - with mock.patch.object(self.dbapi, 'get_strategy_by_id', - autospec=True) as mock_get_strategy: - mock_get_strategy.return_value = self.fake_strategy - with mock.patch.object(self.dbapi, 'destroy_strategy', - autospec=True) as mock_destroy_strategy: - strategy = objects.Strategy.get_by_id(self.context, _id) - strategy.destroy() - mock_get_strategy.assert_called_once_with(self.context, _id) - mock_destroy_strategy.assert_called_once_with(_id) - self.assertEqual(self.context, strategy._context) + mock_get_strategy.return_value = self.fake_strategy + strategy = objects.Strategy.get_by_id( + self.context, _id, eager=self.eager) + strategy.name = 'UPDATED NAME' + strategy.save() - def test_save(self): - _id = self.fake_strategy['id'] - with mock.patch.object(self.dbapi, 'get_strategy_by_id', - autospec=True) as mock_get_strategy: - mock_get_strategy.return_value = self.fake_strategy - with mock.patch.object(self.dbapi, 'update_strategy', - autospec=True) as mock_update_strategy: - strategy = objects.Strategy.get_by_id(self.context, _id) - strategy.name = 'UPDATED NAME' - strategy.save() + mock_get_strategy.assert_called_once_with( + self.context, _id, eager=self.eager) + mock_update_strategy.assert_called_once_with( + _id, {'name': 'UPDATED NAME'}) + self.assertEqual(self.context, strategy._context) + self.eager_load_strategy_assert(strategy) - mock_get_strategy.assert_called_once_with(self.context, _id) - mock_update_strategy.assert_called_once_with( - _id, {'name': 'UPDATED NAME'}) - self.assertEqual(self.context, strategy._context) - - def test_refresh(self): + @mock.patch.object(db_api.Connection, 'get_strategy_by_id') + def test_refresh(self, mock_get_strategy): _id = self.fake_strategy['id'] returns = [dict(self.fake_strategy, name="first name"), dict(self.fake_strategy, name="second name")] - expected = [mock.call(self.context, _id), - mock.call(self.context, _id)] - with mock.patch.object(self.dbapi, 'get_strategy_by_id', - side_effect=returns, - autospec=True) as mock_get_strategy: - strategy = objects.Strategy.get(self.context, _id) - self.assertEqual("first name", strategy.name) - strategy.refresh() - self.assertEqual("second name", strategy.name) - self.assertEqual(expected, mock_get_strategy.call_args_list) - self.assertEqual(self.context, strategy._context) + mock_get_strategy.side_effect = returns + expected = [mock.call(self.context, _id, eager=self.eager), + mock.call(self.context, _id, eager=self.eager)] + strategy = objects.Strategy.get(self.context, _id, eager=self.eager) + self.assertEqual("first name", strategy.name) + strategy.refresh(eager=self.eager) + self.assertEqual("second name", strategy.name) + self.assertEqual(expected, mock_get_strategy.call_args_list) + self.assertEqual(self.context, strategy._context) + self.eager_load_strategy_assert(strategy) - def test_soft_delete(self): + +class TestCreateDeleteStrategyObject(base.DbTestCase): + + def setUp(self): + super(TestCreateDeleteStrategyObject, self).setUp() + self.fake_goal = utils.create_test_goal() + self.fake_strategy = utils.get_test_strategy(goal_id=self.fake_goal.id) + + @mock.patch.object(db_api.Connection, 'create_strategy') + def test_create(self, mock_create_strategy): + mock_create_strategy.return_value = self.fake_strategy + strategy = objects.Strategy(self.context, **self.fake_strategy) + strategy.create() + mock_create_strategy.assert_called_once_with(self.fake_strategy) + self.assertEqual(self.context, strategy._context) + + @mock.patch.object(db_api.Connection, 'soft_delete_strategy') + @mock.patch.object(db_api.Connection, 'get_strategy_by_id') + def test_soft_delete(self, mock_get_strategy, mock_soft_delete): _id = self.fake_strategy['id'] - with mock.patch.object(self.dbapi, 'get_strategy_by_id', - autospec=True) as mock_get_strategy: - mock_get_strategy.return_value = self.fake_strategy - with mock.patch.object(self.dbapi, 'soft_delete_strategy', - autospec=True) as mock_soft_delete: - strategy = objects.Strategy.get_by_id(self.context, _id) - strategy.soft_delete() - mock_get_strategy.assert_called_once_with(self.context, _id) - mock_soft_delete.assert_called_once_with(_id) - self.assertEqual(self.context, strategy._context) + mock_get_strategy.return_value = self.fake_strategy + strategy = objects.Strategy.get_by_id(self.context, _id) + strategy.soft_delete() + mock_get_strategy.assert_called_once_with( + self.context, _id, eager=False) + mock_soft_delete.assert_called_once_with(_id) + self.assertEqual(self.context, strategy._context) + + @mock.patch.object(db_api.Connection, 'destroy_strategy') + @mock.patch.object(db_api.Connection, 'get_strategy_by_id') + def test_destroy(self, mock_get_strategy, mock_destroy_strategy): + _id = self.fake_strategy['id'] + mock_get_strategy.return_value = self.fake_strategy + strategy = objects.Strategy.get_by_id(self.context, _id) + strategy.destroy() + mock_get_strategy.assert_called_once_with( + self.context, _id, eager=False) + mock_destroy_strategy.assert_called_once_with(_id) + self.assertEqual(self.context, strategy._context) diff --git a/watcher/tests/objects/utils.py b/watcher/tests/objects/utils.py index 873f3745f..6123087a7 100644 --- a/watcher/tests/objects/utils.py +++ b/watcher/tests/objects/utils.py @@ -217,6 +217,14 @@ def get_test_strategy(context, **kw): strategy = objects.Strategy(context) for key in db_strategy: setattr(strategy, key, db_strategy[key]) + + # ObjectField checks for the object type, so if we want to simulate a + # non-eager object loading, the field should not be referenced at all. + # Contrarily, eager loading need the data to be casted to the object type + # that was specified by the ObjectField. + if kw.get('goal'): + strategy.goal = objects.Goal(context, **kw.get('goal')) + return strategy