diff --git a/watcher/objects/base.py b/watcher/objects/base.py index d23e3fd48..97d3bed08 100644 --- a/watcher/objects/base.py +++ b/watcher/objects/base.py @@ -103,6 +103,19 @@ class WatcherPersistentObject(object): 'deleted_at': ovo_fields.DateTimeField(nullable=True), } + # Mapping between the object field name and a 2-tuple pair composed of + # its object type (e.g. objects.RelatedObject) and the the name of the + # model field related ID (or UUID) foreign key field. + # e.g.: + # + # fields = { + # # [...] + # 'related_object_id': fields.IntegerField(), # Foreign key + # 'related_object': wfields.ObjectField('RelatedObject'), + # } + # {'related_object': (objects.RelatedObject, 'related_object_id')} + object_fields = {} + def obj_refresh(self, loaded_object): """Applies updates for objects that inherit from base.WatcherObject. @@ -116,17 +129,39 @@ class WatcherPersistentObject(object): self[field] = loaded_object[field] @staticmethod - def _from_db_object(obj, db_object): + def _from_db_object(obj, db_object, eager=False): """Converts a database entity to a formal object. :param obj: An object of the class. :param db_object: A DB model of the object + :param eager: Enable the loading of object fields (Default: False) :return: The object of the class with the database entity added - """ + obj_class = type(obj) + object_fields = obj_class.object_fields for field in obj.fields: - obj[field] = db_object[field] + if field not in object_fields: + obj[field] = db_object[field] + + if eager: + # Load object fields + context = obj._context + loadable_fields = ( + (obj_field, related_obj_cls, rel_id) + for obj_field, (related_obj_cls, rel_id) + in object_fields.items() + if obj[rel_id] + ) + for obj_field, related_obj_cls, rel_id in loadable_fields: + if db_object.get(obj_field) and obj[rel_id]: + # The object field data was eagerly loaded alongside + # the main object data + obj[obj_field] = related_obj_cls._from_db_object( + related_obj_cls(context), db_object[obj_field]) + else: + # The object field data wasn't loaded yet + obj[obj_field] = related_obj_cls.get(context, obj[rel_id]) obj.obj_reset_changes() return obj diff --git a/watcher/objects/fields.py b/watcher/objects/fields.py index c1fe2061f..7b33bfae7 100644 --- a/watcher/objects/fields.py +++ b/watcher/objects/fields.py @@ -24,12 +24,13 @@ from oslo_versionedobjects import fields LOG = log.getLogger(__name__) -IntegerField = fields.IntegerField -UUIDField = fields.UUIDField -StringField = fields.StringField -DateTimeField = fields.DateTimeField BooleanField = fields.BooleanField +DateTimeField = fields.DateTimeField +IntegerField = fields.IntegerField ListOfStringsField = fields.ListOfStringsField +ObjectField = fields.ObjectField +StringField = fields.StringField +UUIDField = fields.UUIDField class Numeric(fields.FieldType): diff --git a/watcher/objects/strategy.py b/watcher/objects/strategy.py index 6f95a103e..584c8ff23 100644 --- a/watcher/objects/strategy.py +++ b/watcher/objects/strategy.py @@ -17,6 +17,7 @@ from watcher.common import exception from watcher.common import utils from watcher.db import api as db_api +from watcher import objects from watcher.objects import base from watcher.objects import fields as wfields @@ -25,6 +26,10 @@ from watcher.objects import fields as wfields class Strategy(base.WatcherPersistentObject, base.WatcherObject, base.WatcherObjectDictCompat): + # Version 1.0: Initial version + # Version 1.1: Added Goal object field + VERSION = '1.1' + dbapi = db_api.get_instance() fields = { @@ -34,10 +39,13 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, 'display_name': wfields.StringField(), 'goal_id': wfields.IntegerField(), 'parameters_spec': wfields.FlexibleDictField(nullable=True), + 'goal': wfields.ObjectField('Goal', nullable=True), } + object_fields = {'goal': (objects.Goal, 'goal_id')} + @base.remotable_classmethod - def get(cls, context, strategy_id): + def get(cls, context, strategy_id, eager=False): """Find a strategy based on its id or uuid :param context: Security context. NOTE: This should only @@ -47,17 +55,18 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, A context should be set when instantiating the object, e.g.: Strategy(context) :param strategy_id: the id *or* uuid of a strategy. - :returns: a :class:`Strategy` object. + :param eager: Load object fields if True (Default: False) + :returns: A :class:`Strategy` object. """ if utils.is_int_like(strategy_id): - return cls.get_by_id(context, strategy_id) + return cls.get_by_id(context, strategy_id, eager=eager) elif utils.is_uuid_like(strategy_id): - return cls.get_by_uuid(context, strategy_id) + return cls.get_by_uuid(context, strategy_id, eager=eager) else: raise exception.InvalidIdentity(identity=strategy_id) @base.remotable_classmethod - def get_by_id(cls, context, strategy_id): + def get_by_id(cls, context, strategy_id, eager=False): """Find a strategy based on its integer id :param context: Security context. NOTE: This should only @@ -67,14 +76,16 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, A context should be set when instantiating the object, e.g.: Strategy(context) :param strategy_id: the id of a strategy. - :returns: a :class:`Strategy` object. + :param eager: Load object fields if True (Default: False) + :returns: A :class:`Strategy` object. """ - db_strategy = cls.dbapi.get_strategy_by_id(context, strategy_id) - strategy = Strategy._from_db_object(cls(context), db_strategy) + db_strategy = cls.dbapi.get_strategy_by_id( + context, strategy_id, eager=eager) + strategy = cls._from_db_object(cls(context), db_strategy, eager=eager) return strategy @base.remotable_classmethod - def get_by_uuid(cls, context, uuid): + def get_by_uuid(cls, context, uuid, eager=False): """Find a strategy based on uuid :param context: Security context. NOTE: This should only @@ -84,29 +95,33 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, A context should be set when instantiating the object, e.g.: Strategy(context) :param uuid: the uuid of a strategy. - :returns: a :class:`Strategy` object. + :param eager: Load object fields if True (Default: False) + :returns: A :class:`Strategy` object. """ - db_strategy = cls.dbapi.get_strategy_by_uuid(context, uuid) - strategy = cls._from_db_object(cls(context), db_strategy) + db_strategy = cls.dbapi.get_strategy_by_uuid( + context, uuid, eager=eager) + strategy = cls._from_db_object(cls(context), db_strategy, eager=eager) return strategy @base.remotable_classmethod - def get_by_name(cls, context, name): + def get_by_name(cls, context, name, eager=False): """Find a strategy based on name - :param name: the name of a strategy. :param context: Security context - :returns: a :class:`Strategy` object. + :param name: the name of a strategy. + :param eager: Load object fields if True (Default: False) + :returns: A :class:`Strategy` object. """ - db_strategy = cls.dbapi.get_strategy_by_name(context, name) - strategy = cls._from_db_object(cls(context), db_strategy) + db_strategy = cls.dbapi.get_strategy_by_name( + context, name, eager=eager) + strategy = cls._from_db_object(cls(context), db_strategy, eager=eager) return strategy @base.remotable_classmethod def list(cls, context, limit=None, marker=None, filters=None, - sort_key=None, sort_dir=None): + sort_key=None, sort_dir=None, eager=False): """Return a list of :class:`Strategy` objects. :param context: Security context. NOTE: This should only @@ -115,11 +130,12 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, argument, even though we don't use it. A context should be set when instantiating the object, e.g.: Strategy(context) - :param filters: dict mapping the filter key to a value. :param limit: maximum number of resources to return in a single result. :param marker: pagination marker for large data sets. + :param filters: dict mapping the filter key to a value. :param sort_key: column to sort results by. - :param sort_dir: direction to sort. "asc" or "desc". + :param sort_dir: direction to sort. "asc" or "desc`". + :param eager: Load object fields if True (Default: False) :returns: a list of :class:`Strategy` object. """ db_strategies = cls.dbapi.get_strategy_list( @@ -130,7 +146,7 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, sort_key=sort_key, sort_dir=sort_dir) - return [cls._from_db_object(cls(context), obj) + return [cls._from_db_object(cls(context), obj, eager=eager) for obj in db_strategies] @base.remotable @@ -143,11 +159,14 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, argument, even though we don't use it. A context should be set when instantiating the object, e.g.: Strategy(context) + :returns: A :class:`Strategy` object. """ values = self.obj_get_changes() db_strategy = self.dbapi.create_strategy(values) - self._from_db_object(self, db_strategy) + # Note(v-francoise): Always load eagerly upon creation so we can send + # notifications containing information about the related relationships + self._from_db_object(self, db_strategy, eager=True) def destroy(self, context=None): """Delete the :class:`Strategy` from the DB. @@ -182,7 +201,7 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, self.obj_reset_changes() @base.remotable - def refresh(self, context=None): + def refresh(self, context=None, eager=False): """Loads updates for this :class:`Strategy`. Loads a strategy with the same uuid from the database and @@ -195,8 +214,10 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject, argument, even though we don't use it. A context should be set when instantiating the object, e.g.: Strategy(context) + :param eager: Load object fields if True (Default: False) """ - current = self.__class__.get_by_id(self._context, strategy_id=self.id) + current = self.__class__.get_by_id( + self._context, strategy_id=self.id, eager=eager) for field in self.fields: if (hasattr(self, base.get_attrname(field)) and self[field] != current[field]): diff --git a/watcher/tests/api/v1/test_strategies.py b/watcher/tests/api/v1/test_strategies.py index 34121a800..5be3fd448 100644 --- a/watcher/tests/api/v1/test_strategies.py +++ b/watcher/tests/api/v1/test_strategies.py @@ -21,6 +21,11 @@ from watcher.tests.objects import utils as obj_utils class TestListStrategy(api_base.FunctionalTest): + def setUp(self): + super(TestListStrategy, self).setUp() + self.fake_goal = obj_utils.create_test_goal( + self.context, uuid=utils.generate_uuid()) + def _assert_strategy_fields(self, strategy): strategy_fields = ['uuid', 'name', 'display_name', 'goal_uuid'] for field in strategy_fields: @@ -61,7 +66,6 @@ class TestListStrategy(api_base.FunctionalTest): self.assertEqual(404, response.status_int) def test_detail(self): - obj_utils.create_test_goal(self.context) strategy = obj_utils.create_test_strategy(self.context) response = self.get_json('/strategies/detail') self.assertEqual(strategy.uuid, response['strategies'][0]["uuid"]) @@ -78,7 +82,6 @@ class TestListStrategy(api_base.FunctionalTest): self.assertEqual(404, response.status_int) def test_many(self): - obj_utils.create_test_goal(self.context) strategy_list = [] for idx in range(1, 6): strategy = obj_utils.create_test_strategy( @@ -132,12 +135,12 @@ class TestListStrategy(api_base.FunctionalTest): def test_filter_by_goal_uuid(self): goal1 = obj_utils.create_test_goal( self.context, - id=1, + id=2, uuid=utils.generate_uuid(), name='My_Goal 1') goal2 = obj_utils.create_test_goal( self.context, - id=2, + id=3, uuid=utils.generate_uuid(), name='My Goal 2') @@ -164,12 +167,12 @@ class TestListStrategy(api_base.FunctionalTest): def test_filter_by_goal_name(self): goal1 = obj_utils.create_test_goal( self.context, - id=1, + id=2, uuid=utils.generate_uuid(), name='My_Goal 1') goal2 = obj_utils.create_test_goal( self.context, - id=2, + id=3, uuid=utils.generate_uuid(), name='My Goal 2') @@ -196,6 +199,11 @@ class TestListStrategy(api_base.FunctionalTest): class TestStrategyPolicyEnforcement(api_base.FunctionalTest): + def setUp(self): + super(TestStrategyPolicyEnforcement, self).setUp() + self.fake_goal = obj_utils.create_test_goal( + self.context, uuid=utils.generate_uuid()) + def _common_policy_check(self, rule, func, *arg, **kwarg): self.policy.set_rules({ "admin_api": "(role:admin or role:administrator)", @@ -227,8 +235,8 @@ class TestStrategyPolicyEnforcement(api_base.FunctionalTest): expect_errors=True) -class TestStrategyEnforcementWithAdminContext(TestListStrategy, - api_base.AdminRoleTest): +class TestStrategyEnforcementWithAdminContext( + TestListStrategy, api_base.AdminRoleTest): def setUp(self): super(TestStrategyEnforcementWithAdminContext, self).setUp() diff --git a/watcher/tests/db/utils.py b/watcher/tests/db/utils.py index ca5ef2dd9..ad91aa9cd 100644 --- a/watcher/tests/db/utils.py +++ b/watcher/tests/db/utils.py @@ -199,7 +199,7 @@ def create_test_scoring_engine(**kwargs): def get_test_strategy(**kwargs): - return { + strategy_data = { 'id': kwargs.get('id', 1), 'uuid': kwargs.get('uuid', 'cb3d0b58-4415-4d90-b75b-1e96878730e3'), 'name': kwargs.get('name', 'TEST'), @@ -208,9 +208,16 @@ def get_test_strategy(**kwargs): 'created_at': kwargs.get('created_at'), 'updated_at': kwargs.get('updated_at'), 'deleted_at': kwargs.get('deleted_at'), - 'parameters_spec': kwargs.get('parameters_spec', {}) + 'parameters_spec': kwargs.get('parameters_spec', {}), } + # goal ObjectField doesn't allow None nor dict, so if we want to simulate a + # non-eager object loading, the field should not be referenced at all. + if kwargs.get('goal'): + strategy_data['goal'] = kwargs.get('goal') + + return strategy_data + def get_test_service(**kwargs): return { diff --git a/watcher/tests/decision_engine/planner/test_default_planner.py b/watcher/tests/decision_engine/planner/test_default_planner.py index 30c7da5b9..34656f1ed 100644 --- a/watcher/tests/decision_engine/planner/test_default_planner.py +++ b/watcher/tests/decision_engine/planner/test_default_planner.py @@ -159,8 +159,11 @@ class TestDefaultPlanner(base.DbTestCase): 'migrate': 3 } - obj_utils.create_test_audit_template(self.context) - self.strategy = obj_utils.create_test_strategy(self.context) + self.goal = obj_utils.create_test_goal(self.context) + self.strategy = obj_utils.create_test_strategy( + self.context, goal_id=self.goal.id) + obj_utils.create_test_audit_template( + self.context, goal_id=self.goal.id, strategy_id=self.strategy.id) p = mock.patch.object(db_api.BaseConnection, 'create_action_plan') self.mock_create_action_plan = p.start() @@ -185,7 +188,8 @@ class TestDefaultPlanner(base.DbTestCase): @mock.patch.object(objects.Strategy, 'get_by_name') def test_schedule_scheduled_empty(self, m_get_by_name): m_get_by_name.return_value = self.strategy - audit = db_utils.create_test_audit(uuid=utils.generate_uuid()) + audit = db_utils.create_test_audit( + goal_id=self.goal.id, strategy_id=self.strategy.id) fake_solution = SolutionFakerSingleHyp.build() action_plan = self.default_planner.schedule(self.context, audit.id, fake_solution) @@ -194,8 +198,9 @@ class TestDefaultPlanner(base.DbTestCase): @mock.patch.object(objects.Strategy, 'get_by_name') def test_scheduler_warning_empty_action_plan(self, m_get_by_name): m_get_by_name.return_value = self.strategy - audit = db_utils.create_test_audit(uuid=utils.generate_uuid()) + audit = db_utils.create_test_audit( + goal_id=self.goal.id, strategy_id=self.strategy.id) fake_solution = SolutionFaker.build() - action_plan = self.default_planner.schedule(self.context, - audit.id, fake_solution) + action_plan = self.default_planner.schedule( + self.context, audit.id, fake_solution) self.assertIsNotNone(action_plan.uuid) diff --git a/watcher/tests/objects/test_objects.py b/watcher/tests/objects/test_objects.py index 165b71bdc..d71445875 100644 --- a/watcher/tests/objects/test_objects.py +++ b/watcher/tests/objects/test_objects.py @@ -410,7 +410,7 @@ class TestObject(_LocalTest, _TestObject): # The fingerprint values should only be changed if there is a version bump. expected_object_fingerprints = { 'Goal': '1.0-93881622db05e7b67a65ca885b4a022e', - 'Strategy': '1.0-e60f62cc854c6e63fb1c3befbfc8629e', + 'Strategy': '1.1-73f164491bdd4c034f48083a51bdeb7b', 'AuditTemplate': '1.0-7432ee4d3ce0c7cbb9d11a4565ee8eb6', 'Audit': '1.0-ebfc5360d019baf583a10a8a27071c97', 'ActionPlan': '1.0-cc76fd7f0e8479aeff817dd266341de4', diff --git a/watcher/tests/objects/test_strategy.py b/watcher/tests/objects/test_strategy.py index dc315ccbe..3d1a47dc5 100644 --- a/watcher/tests/objects/test_strategy.py +++ b/watcher/tests/objects/test_strategy.py @@ -15,7 +15,9 @@ import mock + from watcher.common import exception +from watcher.db.sqlalchemy import api as db_api from watcher import objects from watcher.tests.db import base from watcher.tests.db import utils @@ -23,107 +25,143 @@ from watcher.tests.db import utils class TestStrategyObject(base.DbTestCase): + goal_id = 2 + + scenarios = [ + ('non_eager', dict( + eager=False, fake_strategy=utils.get_test_strategy( + goal_id=goal_id))), + ('eager_with_non_eager_load', dict( + eager=True, fake_strategy=utils.get_test_strategy( + goal_id=goal_id))), + ('eager_with_eager_load', dict( + eager=True, fake_strategy=utils.get_test_strategy( + goal_id=goal_id, goal=utils.get_test_goal(id=goal_id)))), + ] + def setUp(self): super(TestStrategyObject, self).setUp() - self.fake_strategy = utils.get_test_strategy() + self.fake_goal = utils.create_test_goal(id=self.goal_id) - def test_get_by_id(self): + def eager_load_strategy_assert(self, strategy): + if self.eager: + self.assertIsNotNone(strategy.goal) + fields_to_check = set( + super(objects.Goal, objects.Goal).fields + ).symmetric_difference(objects.Goal.fields) + db_data = { + k: v for k, v in self.fake_goal.as_dict().items() + if k in fields_to_check} + object_data = { + k: v for k, v in strategy.goal.as_dict().items() + if k in fields_to_check} + self.assertEqual(db_data, object_data) + + @mock.patch.object(db_api.Connection, 'get_strategy_by_id') + def test_get_by_id(self, mock_get_strategy): strategy_id = self.fake_strategy['id'] - with mock.patch.object(self.dbapi, 'get_strategy_by_id', - autospec=True) as mock_get_strategy: - mock_get_strategy.return_value = self.fake_strategy - strategy = objects.Strategy.get(self.context, strategy_id) - mock_get_strategy.assert_called_once_with(self.context, - strategy_id) - self.assertEqual(self.context, strategy._context) + mock_get_strategy.return_value = self.fake_strategy + strategy = objects.Strategy.get( + self.context, strategy_id, eager=self.eager) + mock_get_strategy.assert_called_once_with( + self.context, strategy_id, eager=self.eager) + self.assertEqual(self.context, strategy._context) + self.eager_load_strategy_assert(strategy) - def test_get_by_uuid(self): + @mock.patch.object(db_api.Connection, 'get_strategy_by_uuid') + def test_get_by_uuid(self, mock_get_strategy): uuid = self.fake_strategy['uuid'] - with mock.patch.object(self.dbapi, 'get_strategy_by_uuid', - autospec=True) as mock_get_strategy: - mock_get_strategy.return_value = self.fake_strategy - strategy = objects.Strategy.get(self.context, uuid) - mock_get_strategy.assert_called_once_with(self.context, uuid) - self.assertEqual(self.context, strategy._context) + mock_get_strategy.return_value = self.fake_strategy + strategy = objects.Strategy.get(self.context, uuid, eager=self.eager) + mock_get_strategy.assert_called_once_with( + self.context, uuid, eager=self.eager) + self.assertEqual(self.context, strategy._context) + self.eager_load_strategy_assert(strategy) def test_get_bad_uuid(self): self.assertRaises(exception.InvalidIdentity, objects.Strategy.get, self.context, 'not-a-uuid') - def test_list(self): - with mock.patch.object(self.dbapi, 'get_strategy_list', - autospec=True) as mock_get_list: - mock_get_list.return_value = [self.fake_strategy] - strategies = objects.Strategy.list(self.context) - self.assertEqual(1, mock_get_list.call_count, 1) - self.assertEqual(1, len(strategies)) - self.assertIsInstance(strategies[0], objects.Strategy) - self.assertEqual(self.context, strategies[0]._context) + @mock.patch.object(db_api.Connection, 'get_strategy_list') + def test_list(self, mock_get_list): + mock_get_list.return_value = [self.fake_strategy] + strategies = objects.Strategy.list(self.context, eager=self.eager) + self.assertEqual(1, mock_get_list.call_count, 1) + self.assertEqual(1, len(strategies)) + self.assertIsInstance(strategies[0], objects.Strategy) + self.assertEqual(self.context, strategies[0]._context) + for strategy in strategies: + self.eager_load_strategy_assert(strategy) - def test_create(self): - with mock.patch.object(self.dbapi, 'create_strategy', - autospec=True) as mock_create_strategy: - mock_create_strategy.return_value = self.fake_strategy - strategy = objects.Strategy(self.context, **self.fake_strategy) - - strategy.create() - mock_create_strategy.assert_called_once_with(self.fake_strategy) - self.assertEqual(self.context, strategy._context) - - def test_destroy(self): + @mock.patch.object(db_api.Connection, 'update_strategy') + @mock.patch.object(db_api.Connection, 'get_strategy_by_id') + def test_save(self, mock_get_strategy, mock_update_strategy): _id = self.fake_strategy['id'] - with mock.patch.object(self.dbapi, 'get_strategy_by_id', - autospec=True) as mock_get_strategy: - mock_get_strategy.return_value = self.fake_strategy - with mock.patch.object(self.dbapi, 'destroy_strategy', - autospec=True) as mock_destroy_strategy: - strategy = objects.Strategy.get_by_id(self.context, _id) - strategy.destroy() - mock_get_strategy.assert_called_once_with(self.context, _id) - mock_destroy_strategy.assert_called_once_with(_id) - self.assertEqual(self.context, strategy._context) + mock_get_strategy.return_value = self.fake_strategy + strategy = objects.Strategy.get_by_id( + self.context, _id, eager=self.eager) + strategy.name = 'UPDATED NAME' + strategy.save() - def test_save(self): - _id = self.fake_strategy['id'] - with mock.patch.object(self.dbapi, 'get_strategy_by_id', - autospec=True) as mock_get_strategy: - mock_get_strategy.return_value = self.fake_strategy - with mock.patch.object(self.dbapi, 'update_strategy', - autospec=True) as mock_update_strategy: - strategy = objects.Strategy.get_by_id(self.context, _id) - strategy.name = 'UPDATED NAME' - strategy.save() + mock_get_strategy.assert_called_once_with( + self.context, _id, eager=self.eager) + mock_update_strategy.assert_called_once_with( + _id, {'name': 'UPDATED NAME'}) + self.assertEqual(self.context, strategy._context) + self.eager_load_strategy_assert(strategy) - mock_get_strategy.assert_called_once_with(self.context, _id) - mock_update_strategy.assert_called_once_with( - _id, {'name': 'UPDATED NAME'}) - self.assertEqual(self.context, strategy._context) - - def test_refresh(self): + @mock.patch.object(db_api.Connection, 'get_strategy_by_id') + def test_refresh(self, mock_get_strategy): _id = self.fake_strategy['id'] returns = [dict(self.fake_strategy, name="first name"), dict(self.fake_strategy, name="second name")] - expected = [mock.call(self.context, _id), - mock.call(self.context, _id)] - with mock.patch.object(self.dbapi, 'get_strategy_by_id', - side_effect=returns, - autospec=True) as mock_get_strategy: - strategy = objects.Strategy.get(self.context, _id) - self.assertEqual("first name", strategy.name) - strategy.refresh() - self.assertEqual("second name", strategy.name) - self.assertEqual(expected, mock_get_strategy.call_args_list) - self.assertEqual(self.context, strategy._context) + mock_get_strategy.side_effect = returns + expected = [mock.call(self.context, _id, eager=self.eager), + mock.call(self.context, _id, eager=self.eager)] + strategy = objects.Strategy.get(self.context, _id, eager=self.eager) + self.assertEqual("first name", strategy.name) + strategy.refresh(eager=self.eager) + self.assertEqual("second name", strategy.name) + self.assertEqual(expected, mock_get_strategy.call_args_list) + self.assertEqual(self.context, strategy._context) + self.eager_load_strategy_assert(strategy) - def test_soft_delete(self): + +class TestCreateDeleteStrategyObject(base.DbTestCase): + + def setUp(self): + super(TestCreateDeleteStrategyObject, self).setUp() + self.fake_goal = utils.create_test_goal() + self.fake_strategy = utils.get_test_strategy(goal_id=self.fake_goal.id) + + @mock.patch.object(db_api.Connection, 'create_strategy') + def test_create(self, mock_create_strategy): + mock_create_strategy.return_value = self.fake_strategy + strategy = objects.Strategy(self.context, **self.fake_strategy) + strategy.create() + mock_create_strategy.assert_called_once_with(self.fake_strategy) + self.assertEqual(self.context, strategy._context) + + @mock.patch.object(db_api.Connection, 'soft_delete_strategy') + @mock.patch.object(db_api.Connection, 'get_strategy_by_id') + def test_soft_delete(self, mock_get_strategy, mock_soft_delete): _id = self.fake_strategy['id'] - with mock.patch.object(self.dbapi, 'get_strategy_by_id', - autospec=True) as mock_get_strategy: - mock_get_strategy.return_value = self.fake_strategy - with mock.patch.object(self.dbapi, 'soft_delete_strategy', - autospec=True) as mock_soft_delete: - strategy = objects.Strategy.get_by_id(self.context, _id) - strategy.soft_delete() - mock_get_strategy.assert_called_once_with(self.context, _id) - mock_soft_delete.assert_called_once_with(_id) - self.assertEqual(self.context, strategy._context) + mock_get_strategy.return_value = self.fake_strategy + strategy = objects.Strategy.get_by_id(self.context, _id) + strategy.soft_delete() + mock_get_strategy.assert_called_once_with( + self.context, _id, eager=False) + mock_soft_delete.assert_called_once_with(_id) + self.assertEqual(self.context, strategy._context) + + @mock.patch.object(db_api.Connection, 'destroy_strategy') + @mock.patch.object(db_api.Connection, 'get_strategy_by_id') + def test_destroy(self, mock_get_strategy, mock_destroy_strategy): + _id = self.fake_strategy['id'] + mock_get_strategy.return_value = self.fake_strategy + strategy = objects.Strategy.get_by_id(self.context, _id) + strategy.destroy() + mock_get_strategy.assert_called_once_with( + self.context, _id, eager=False) + mock_destroy_strategy.assert_called_once_with(_id) + self.assertEqual(self.context, strategy._context) diff --git a/watcher/tests/objects/utils.py b/watcher/tests/objects/utils.py index 873f3745f..6123087a7 100644 --- a/watcher/tests/objects/utils.py +++ b/watcher/tests/objects/utils.py @@ -217,6 +217,14 @@ def get_test_strategy(context, **kw): strategy = objects.Strategy(context) for key in db_strategy: setattr(strategy, key, db_strategy[key]) + + # ObjectField checks for the object type, so if we want to simulate a + # non-eager object loading, the field should not be referenced at all. + # Contrarily, eager loading need the data to be casted to the object type + # that was specified by the ObjectField. + if kw.get('goal'): + strategy.goal = objects.Goal(context, **kw.get('goal')) + return strategy