Merge "Added 'goal' ObjectField for Strategy object"

This commit is contained in:
Jenkins
2016-11-10 09:38:04 +00:00
committed by Gerrit Code Review
9 changed files with 255 additions and 132 deletions

View File

@@ -103,6 +103,19 @@ class WatcherPersistentObject(object):
'deleted_at': ovo_fields.DateTimeField(nullable=True),
}
# Mapping between the object field name and a 2-tuple pair composed of
# its object type (e.g. objects.RelatedObject) and the the name of the
# model field related ID (or UUID) foreign key field.
# e.g.:
#
# fields = {
# # [...]
# 'related_object_id': fields.IntegerField(), # Foreign key
# 'related_object': wfields.ObjectField('RelatedObject'),
# }
# {'related_object': (objects.RelatedObject, 'related_object_id')}
object_fields = {}
def obj_refresh(self, loaded_object):
"""Applies updates for objects that inherit from base.WatcherObject.
@@ -116,17 +129,39 @@ class WatcherPersistentObject(object):
self[field] = loaded_object[field]
@staticmethod
def _from_db_object(obj, db_object):
def _from_db_object(obj, db_object, eager=False):
"""Converts a database entity to a formal object.
:param obj: An object of the class.
:param db_object: A DB model of the object
:param eager: Enable the loading of object fields (Default: False)
:return: The object of the class with the database entity added
"""
obj_class = type(obj)
object_fields = obj_class.object_fields
for field in obj.fields:
obj[field] = db_object[field]
if field not in object_fields:
obj[field] = db_object[field]
if eager:
# Load object fields
context = obj._context
loadable_fields = (
(obj_field, related_obj_cls, rel_id)
for obj_field, (related_obj_cls, rel_id)
in object_fields.items()
if obj[rel_id]
)
for obj_field, related_obj_cls, rel_id in loadable_fields:
if db_object.get(obj_field) and obj[rel_id]:
# The object field data was eagerly loaded alongside
# the main object data
obj[obj_field] = related_obj_cls._from_db_object(
related_obj_cls(context), db_object[obj_field])
else:
# The object field data wasn't loaded yet
obj[obj_field] = related_obj_cls.get(context, obj[rel_id])
obj.obj_reset_changes()
return obj

View File

@@ -24,12 +24,13 @@ from oslo_versionedobjects import fields
LOG = log.getLogger(__name__)
IntegerField = fields.IntegerField
UUIDField = fields.UUIDField
StringField = fields.StringField
DateTimeField = fields.DateTimeField
BooleanField = fields.BooleanField
DateTimeField = fields.DateTimeField
IntegerField = fields.IntegerField
ListOfStringsField = fields.ListOfStringsField
ObjectField = fields.ObjectField
StringField = fields.StringField
UUIDField = fields.UUIDField
class Numeric(fields.FieldType):

View File

@@ -17,6 +17,7 @@
from watcher.common import exception
from watcher.common import utils
from watcher.db import api as db_api
from watcher import objects
from watcher.objects import base
from watcher.objects import fields as wfields
@@ -25,6 +26,10 @@ from watcher.objects import fields as wfields
class Strategy(base.WatcherPersistentObject, base.WatcherObject,
base.WatcherObjectDictCompat):
# Version 1.0: Initial version
# Version 1.1: Added Goal object field
VERSION = '1.1'
dbapi = db_api.get_instance()
fields = {
@@ -34,10 +39,13 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
'display_name': wfields.StringField(),
'goal_id': wfields.IntegerField(),
'parameters_spec': wfields.FlexibleDictField(nullable=True),
'goal': wfields.ObjectField('Goal', nullable=True),
}
object_fields = {'goal': (objects.Goal, 'goal_id')}
@base.remotable_classmethod
def get(cls, context, strategy_id):
def get(cls, context, strategy_id, eager=False):
"""Find a strategy based on its id or uuid
:param context: Security context. NOTE: This should only
@@ -47,17 +55,18 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
A context should be set when instantiating the
object, e.g.: Strategy(context)
:param strategy_id: the id *or* uuid of a strategy.
:returns: a :class:`Strategy` object.
:param eager: Load object fields if True (Default: False)
:returns: A :class:`Strategy` object.
"""
if utils.is_int_like(strategy_id):
return cls.get_by_id(context, strategy_id)
return cls.get_by_id(context, strategy_id, eager=eager)
elif utils.is_uuid_like(strategy_id):
return cls.get_by_uuid(context, strategy_id)
return cls.get_by_uuid(context, strategy_id, eager=eager)
else:
raise exception.InvalidIdentity(identity=strategy_id)
@base.remotable_classmethod
def get_by_id(cls, context, strategy_id):
def get_by_id(cls, context, strategy_id, eager=False):
"""Find a strategy based on its integer id
:param context: Security context. NOTE: This should only
@@ -67,14 +76,16 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
A context should be set when instantiating the
object, e.g.: Strategy(context)
:param strategy_id: the id of a strategy.
:returns: a :class:`Strategy` object.
:param eager: Load object fields if True (Default: False)
:returns: A :class:`Strategy` object.
"""
db_strategy = cls.dbapi.get_strategy_by_id(context, strategy_id)
strategy = Strategy._from_db_object(cls(context), db_strategy)
db_strategy = cls.dbapi.get_strategy_by_id(
context, strategy_id, eager=eager)
strategy = cls._from_db_object(cls(context), db_strategy, eager=eager)
return strategy
@base.remotable_classmethod
def get_by_uuid(cls, context, uuid):
def get_by_uuid(cls, context, uuid, eager=False):
"""Find a strategy based on uuid
:param context: Security context. NOTE: This should only
@@ -84,29 +95,33 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
A context should be set when instantiating the
object, e.g.: Strategy(context)
:param uuid: the uuid of a strategy.
:returns: a :class:`Strategy` object.
:param eager: Load object fields if True (Default: False)
:returns: A :class:`Strategy` object.
"""
db_strategy = cls.dbapi.get_strategy_by_uuid(context, uuid)
strategy = cls._from_db_object(cls(context), db_strategy)
db_strategy = cls.dbapi.get_strategy_by_uuid(
context, uuid, eager=eager)
strategy = cls._from_db_object(cls(context), db_strategy, eager=eager)
return strategy
@base.remotable_classmethod
def get_by_name(cls, context, name):
def get_by_name(cls, context, name, eager=False):
"""Find a strategy based on name
:param name: the name of a strategy.
:param context: Security context
:returns: a :class:`Strategy` object.
:param name: the name of a strategy.
:param eager: Load object fields if True (Default: False)
:returns: A :class:`Strategy` object.
"""
db_strategy = cls.dbapi.get_strategy_by_name(context, name)
strategy = cls._from_db_object(cls(context), db_strategy)
db_strategy = cls.dbapi.get_strategy_by_name(
context, name, eager=eager)
strategy = cls._from_db_object(cls(context), db_strategy, eager=eager)
return strategy
@base.remotable_classmethod
def list(cls, context, limit=None, marker=None, filters=None,
sort_key=None, sort_dir=None):
sort_key=None, sort_dir=None, eager=False):
"""Return a list of :class:`Strategy` objects.
:param context: Security context. NOTE: This should only
@@ -115,11 +130,12 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
argument, even though we don't use it.
A context should be set when instantiating the
object, e.g.: Strategy(context)
:param filters: dict mapping the filter key to a value.
:param limit: maximum number of resources to return in a single result.
:param marker: pagination marker for large data sets.
:param filters: dict mapping the filter key to a value.
:param sort_key: column to sort results by.
:param sort_dir: direction to sort. "asc" or "desc".
:param sort_dir: direction to sort. "asc" or "desc`".
:param eager: Load object fields if True (Default: False)
:returns: a list of :class:`Strategy` object.
"""
db_strategies = cls.dbapi.get_strategy_list(
@@ -130,7 +146,7 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
sort_key=sort_key,
sort_dir=sort_dir)
return [cls._from_db_object(cls(context), obj)
return [cls._from_db_object(cls(context), obj, eager=eager)
for obj in db_strategies]
@base.remotable
@@ -143,11 +159,14 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
argument, even though we don't use it.
A context should be set when instantiating the
object, e.g.: Strategy(context)
:returns: A :class:`Strategy` object.
"""
values = self.obj_get_changes()
db_strategy = self.dbapi.create_strategy(values)
self._from_db_object(self, db_strategy)
# Note(v-francoise): Always load eagerly upon creation so we can send
# notifications containing information about the related relationships
self._from_db_object(self, db_strategy, eager=True)
def destroy(self, context=None):
"""Delete the :class:`Strategy` from the DB.
@@ -182,7 +201,7 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
self.obj_reset_changes()
@base.remotable
def refresh(self, context=None):
def refresh(self, context=None, eager=False):
"""Loads updates for this :class:`Strategy`.
Loads a strategy with the same uuid from the database and
@@ -195,8 +214,10 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
argument, even though we don't use it.
A context should be set when instantiating the
object, e.g.: Strategy(context)
:param eager: Load object fields if True (Default: False)
"""
current = self.__class__.get_by_id(self._context, strategy_id=self.id)
current = self.__class__.get_by_id(
self._context, strategy_id=self.id, eager=eager)
for field in self.fields:
if (hasattr(self, base.get_attrname(field)) and
self[field] != current[field]):

View File

@@ -21,6 +21,11 @@ from watcher.tests.objects import utils as obj_utils
class TestListStrategy(api_base.FunctionalTest):
def setUp(self):
super(TestListStrategy, self).setUp()
self.fake_goal = obj_utils.create_test_goal(
self.context, uuid=utils.generate_uuid())
def _assert_strategy_fields(self, strategy):
strategy_fields = ['uuid', 'name', 'display_name', 'goal_uuid']
for field in strategy_fields:
@@ -61,7 +66,6 @@ class TestListStrategy(api_base.FunctionalTest):
self.assertEqual(404, response.status_int)
def test_detail(self):
obj_utils.create_test_goal(self.context)
strategy = obj_utils.create_test_strategy(self.context)
response = self.get_json('/strategies/detail')
self.assertEqual(strategy.uuid, response['strategies'][0]["uuid"])
@@ -78,7 +82,6 @@ class TestListStrategy(api_base.FunctionalTest):
self.assertEqual(404, response.status_int)
def test_many(self):
obj_utils.create_test_goal(self.context)
strategy_list = []
for idx in range(1, 6):
strategy = obj_utils.create_test_strategy(
@@ -132,12 +135,12 @@ class TestListStrategy(api_base.FunctionalTest):
def test_filter_by_goal_uuid(self):
goal1 = obj_utils.create_test_goal(
self.context,
id=1,
id=2,
uuid=utils.generate_uuid(),
name='My_Goal 1')
goal2 = obj_utils.create_test_goal(
self.context,
id=2,
id=3,
uuid=utils.generate_uuid(),
name='My Goal 2')
@@ -164,12 +167,12 @@ class TestListStrategy(api_base.FunctionalTest):
def test_filter_by_goal_name(self):
goal1 = obj_utils.create_test_goal(
self.context,
id=1,
id=2,
uuid=utils.generate_uuid(),
name='My_Goal 1')
goal2 = obj_utils.create_test_goal(
self.context,
id=2,
id=3,
uuid=utils.generate_uuid(),
name='My Goal 2')
@@ -196,6 +199,11 @@ class TestListStrategy(api_base.FunctionalTest):
class TestStrategyPolicyEnforcement(api_base.FunctionalTest):
def setUp(self):
super(TestStrategyPolicyEnforcement, self).setUp()
self.fake_goal = obj_utils.create_test_goal(
self.context, uuid=utils.generate_uuid())
def _common_policy_check(self, rule, func, *arg, **kwarg):
self.policy.set_rules({
"admin_api": "(role:admin or role:administrator)",
@@ -227,8 +235,8 @@ class TestStrategyPolicyEnforcement(api_base.FunctionalTest):
expect_errors=True)
class TestStrategyEnforcementWithAdminContext(TestListStrategy,
api_base.AdminRoleTest):
class TestStrategyEnforcementWithAdminContext(
TestListStrategy, api_base.AdminRoleTest):
def setUp(self):
super(TestStrategyEnforcementWithAdminContext, self).setUp()

View File

@@ -199,7 +199,7 @@ def create_test_scoring_engine(**kwargs):
def get_test_strategy(**kwargs):
return {
strategy_data = {
'id': kwargs.get('id', 1),
'uuid': kwargs.get('uuid', 'cb3d0b58-4415-4d90-b75b-1e96878730e3'),
'name': kwargs.get('name', 'TEST'),
@@ -208,9 +208,16 @@ def get_test_strategy(**kwargs):
'created_at': kwargs.get('created_at'),
'updated_at': kwargs.get('updated_at'),
'deleted_at': kwargs.get('deleted_at'),
'parameters_spec': kwargs.get('parameters_spec', {})
'parameters_spec': kwargs.get('parameters_spec', {}),
}
# goal ObjectField doesn't allow None nor dict, so if we want to simulate a
# non-eager object loading, the field should not be referenced at all.
if kwargs.get('goal'):
strategy_data['goal'] = kwargs.get('goal')
return strategy_data
def get_test_service(**kwargs):
return {

View File

@@ -159,8 +159,11 @@ class TestDefaultPlanner(base.DbTestCase):
'migrate': 3
}
obj_utils.create_test_audit_template(self.context)
self.strategy = obj_utils.create_test_strategy(self.context)
self.goal = obj_utils.create_test_goal(self.context)
self.strategy = obj_utils.create_test_strategy(
self.context, goal_id=self.goal.id)
obj_utils.create_test_audit_template(
self.context, goal_id=self.goal.id, strategy_id=self.strategy.id)
p = mock.patch.object(db_api.BaseConnection, 'create_action_plan')
self.mock_create_action_plan = p.start()
@@ -185,7 +188,8 @@ class TestDefaultPlanner(base.DbTestCase):
@mock.patch.object(objects.Strategy, 'get_by_name')
def test_schedule_scheduled_empty(self, m_get_by_name):
m_get_by_name.return_value = self.strategy
audit = db_utils.create_test_audit(uuid=utils.generate_uuid())
audit = db_utils.create_test_audit(
goal_id=self.goal.id, strategy_id=self.strategy.id)
fake_solution = SolutionFakerSingleHyp.build()
action_plan = self.default_planner.schedule(self.context,
audit.id, fake_solution)
@@ -194,8 +198,9 @@ class TestDefaultPlanner(base.DbTestCase):
@mock.patch.object(objects.Strategy, 'get_by_name')
def test_scheduler_warning_empty_action_plan(self, m_get_by_name):
m_get_by_name.return_value = self.strategy
audit = db_utils.create_test_audit(uuid=utils.generate_uuid())
audit = db_utils.create_test_audit(
goal_id=self.goal.id, strategy_id=self.strategy.id)
fake_solution = SolutionFaker.build()
action_plan = self.default_planner.schedule(self.context,
audit.id, fake_solution)
action_plan = self.default_planner.schedule(
self.context, audit.id, fake_solution)
self.assertIsNotNone(action_plan.uuid)

View File

@@ -410,7 +410,7 @@ class TestObject(_LocalTest, _TestObject):
# The fingerprint values should only be changed if there is a version bump.
expected_object_fingerprints = {
'Goal': '1.0-93881622db05e7b67a65ca885b4a022e',
'Strategy': '1.0-e60f62cc854c6e63fb1c3befbfc8629e',
'Strategy': '1.1-73f164491bdd4c034f48083a51bdeb7b',
'AuditTemplate': '1.0-7432ee4d3ce0c7cbb9d11a4565ee8eb6',
'Audit': '1.0-ebfc5360d019baf583a10a8a27071c97',
'ActionPlan': '1.0-cc76fd7f0e8479aeff817dd266341de4',

View File

@@ -15,7 +15,9 @@
import mock
from watcher.common import exception
from watcher.db.sqlalchemy import api as db_api
from watcher import objects
from watcher.tests.db import base
from watcher.tests.db import utils
@@ -23,107 +25,143 @@ from watcher.tests.db import utils
class TestStrategyObject(base.DbTestCase):
goal_id = 2
scenarios = [
('non_eager', dict(
eager=False, fake_strategy=utils.get_test_strategy(
goal_id=goal_id))),
('eager_with_non_eager_load', dict(
eager=True, fake_strategy=utils.get_test_strategy(
goal_id=goal_id))),
('eager_with_eager_load', dict(
eager=True, fake_strategy=utils.get_test_strategy(
goal_id=goal_id, goal=utils.get_test_goal(id=goal_id)))),
]
def setUp(self):
super(TestStrategyObject, self).setUp()
self.fake_strategy = utils.get_test_strategy()
self.fake_goal = utils.create_test_goal(id=self.goal_id)
def test_get_by_id(self):
def eager_load_strategy_assert(self, strategy):
if self.eager:
self.assertIsNotNone(strategy.goal)
fields_to_check = set(
super(objects.Goal, objects.Goal).fields
).symmetric_difference(objects.Goal.fields)
db_data = {
k: v for k, v in self.fake_goal.as_dict().items()
if k in fields_to_check}
object_data = {
k: v for k, v in strategy.goal.as_dict().items()
if k in fields_to_check}
self.assertEqual(db_data, object_data)
@mock.patch.object(db_api.Connection, 'get_strategy_by_id')
def test_get_by_id(self, mock_get_strategy):
strategy_id = self.fake_strategy['id']
with mock.patch.object(self.dbapi, 'get_strategy_by_id',
autospec=True) as mock_get_strategy:
mock_get_strategy.return_value = self.fake_strategy
strategy = objects.Strategy.get(self.context, strategy_id)
mock_get_strategy.assert_called_once_with(self.context,
strategy_id)
self.assertEqual(self.context, strategy._context)
mock_get_strategy.return_value = self.fake_strategy
strategy = objects.Strategy.get(
self.context, strategy_id, eager=self.eager)
mock_get_strategy.assert_called_once_with(
self.context, strategy_id, eager=self.eager)
self.assertEqual(self.context, strategy._context)
self.eager_load_strategy_assert(strategy)
def test_get_by_uuid(self):
@mock.patch.object(db_api.Connection, 'get_strategy_by_uuid')
def test_get_by_uuid(self, mock_get_strategy):
uuid = self.fake_strategy['uuid']
with mock.patch.object(self.dbapi, 'get_strategy_by_uuid',
autospec=True) as mock_get_strategy:
mock_get_strategy.return_value = self.fake_strategy
strategy = objects.Strategy.get(self.context, uuid)
mock_get_strategy.assert_called_once_with(self.context, uuid)
self.assertEqual(self.context, strategy._context)
mock_get_strategy.return_value = self.fake_strategy
strategy = objects.Strategy.get(self.context, uuid, eager=self.eager)
mock_get_strategy.assert_called_once_with(
self.context, uuid, eager=self.eager)
self.assertEqual(self.context, strategy._context)
self.eager_load_strategy_assert(strategy)
def test_get_bad_uuid(self):
self.assertRaises(exception.InvalidIdentity,
objects.Strategy.get, self.context, 'not-a-uuid')
def test_list(self):
with mock.patch.object(self.dbapi, 'get_strategy_list',
autospec=True) as mock_get_list:
mock_get_list.return_value = [self.fake_strategy]
strategies = objects.Strategy.list(self.context)
self.assertEqual(1, mock_get_list.call_count, 1)
self.assertEqual(1, len(strategies))
self.assertIsInstance(strategies[0], objects.Strategy)
self.assertEqual(self.context, strategies[0]._context)
@mock.patch.object(db_api.Connection, 'get_strategy_list')
def test_list(self, mock_get_list):
mock_get_list.return_value = [self.fake_strategy]
strategies = objects.Strategy.list(self.context, eager=self.eager)
self.assertEqual(1, mock_get_list.call_count, 1)
self.assertEqual(1, len(strategies))
self.assertIsInstance(strategies[0], objects.Strategy)
self.assertEqual(self.context, strategies[0]._context)
for strategy in strategies:
self.eager_load_strategy_assert(strategy)
def test_create(self):
with mock.patch.object(self.dbapi, 'create_strategy',
autospec=True) as mock_create_strategy:
mock_create_strategy.return_value = self.fake_strategy
strategy = objects.Strategy(self.context, **self.fake_strategy)
strategy.create()
mock_create_strategy.assert_called_once_with(self.fake_strategy)
self.assertEqual(self.context, strategy._context)
def test_destroy(self):
@mock.patch.object(db_api.Connection, 'update_strategy')
@mock.patch.object(db_api.Connection, 'get_strategy_by_id')
def test_save(self, mock_get_strategy, mock_update_strategy):
_id = self.fake_strategy['id']
with mock.patch.object(self.dbapi, 'get_strategy_by_id',
autospec=True) as mock_get_strategy:
mock_get_strategy.return_value = self.fake_strategy
with mock.patch.object(self.dbapi, 'destroy_strategy',
autospec=True) as mock_destroy_strategy:
strategy = objects.Strategy.get_by_id(self.context, _id)
strategy.destroy()
mock_get_strategy.assert_called_once_with(self.context, _id)
mock_destroy_strategy.assert_called_once_with(_id)
self.assertEqual(self.context, strategy._context)
mock_get_strategy.return_value = self.fake_strategy
strategy = objects.Strategy.get_by_id(
self.context, _id, eager=self.eager)
strategy.name = 'UPDATED NAME'
strategy.save()
def test_save(self):
_id = self.fake_strategy['id']
with mock.patch.object(self.dbapi, 'get_strategy_by_id',
autospec=True) as mock_get_strategy:
mock_get_strategy.return_value = self.fake_strategy
with mock.patch.object(self.dbapi, 'update_strategy',
autospec=True) as mock_update_strategy:
strategy = objects.Strategy.get_by_id(self.context, _id)
strategy.name = 'UPDATED NAME'
strategy.save()
mock_get_strategy.assert_called_once_with(
self.context, _id, eager=self.eager)
mock_update_strategy.assert_called_once_with(
_id, {'name': 'UPDATED NAME'})
self.assertEqual(self.context, strategy._context)
self.eager_load_strategy_assert(strategy)
mock_get_strategy.assert_called_once_with(self.context, _id)
mock_update_strategy.assert_called_once_with(
_id, {'name': 'UPDATED NAME'})
self.assertEqual(self.context, strategy._context)
def test_refresh(self):
@mock.patch.object(db_api.Connection, 'get_strategy_by_id')
def test_refresh(self, mock_get_strategy):
_id = self.fake_strategy['id']
returns = [dict(self.fake_strategy, name="first name"),
dict(self.fake_strategy, name="second name")]
expected = [mock.call(self.context, _id),
mock.call(self.context, _id)]
with mock.patch.object(self.dbapi, 'get_strategy_by_id',
side_effect=returns,
autospec=True) as mock_get_strategy:
strategy = objects.Strategy.get(self.context, _id)
self.assertEqual("first name", strategy.name)
strategy.refresh()
self.assertEqual("second name", strategy.name)
self.assertEqual(expected, mock_get_strategy.call_args_list)
self.assertEqual(self.context, strategy._context)
mock_get_strategy.side_effect = returns
expected = [mock.call(self.context, _id, eager=self.eager),
mock.call(self.context, _id, eager=self.eager)]
strategy = objects.Strategy.get(self.context, _id, eager=self.eager)
self.assertEqual("first name", strategy.name)
strategy.refresh(eager=self.eager)
self.assertEqual("second name", strategy.name)
self.assertEqual(expected, mock_get_strategy.call_args_list)
self.assertEqual(self.context, strategy._context)
self.eager_load_strategy_assert(strategy)
def test_soft_delete(self):
class TestCreateDeleteStrategyObject(base.DbTestCase):
def setUp(self):
super(TestCreateDeleteStrategyObject, self).setUp()
self.fake_goal = utils.create_test_goal()
self.fake_strategy = utils.get_test_strategy(goal_id=self.fake_goal.id)
@mock.patch.object(db_api.Connection, 'create_strategy')
def test_create(self, mock_create_strategy):
mock_create_strategy.return_value = self.fake_strategy
strategy = objects.Strategy(self.context, **self.fake_strategy)
strategy.create()
mock_create_strategy.assert_called_once_with(self.fake_strategy)
self.assertEqual(self.context, strategy._context)
@mock.patch.object(db_api.Connection, 'soft_delete_strategy')
@mock.patch.object(db_api.Connection, 'get_strategy_by_id')
def test_soft_delete(self, mock_get_strategy, mock_soft_delete):
_id = self.fake_strategy['id']
with mock.patch.object(self.dbapi, 'get_strategy_by_id',
autospec=True) as mock_get_strategy:
mock_get_strategy.return_value = self.fake_strategy
with mock.patch.object(self.dbapi, 'soft_delete_strategy',
autospec=True) as mock_soft_delete:
strategy = objects.Strategy.get_by_id(self.context, _id)
strategy.soft_delete()
mock_get_strategy.assert_called_once_with(self.context, _id)
mock_soft_delete.assert_called_once_with(_id)
self.assertEqual(self.context, strategy._context)
mock_get_strategy.return_value = self.fake_strategy
strategy = objects.Strategy.get_by_id(self.context, _id)
strategy.soft_delete()
mock_get_strategy.assert_called_once_with(
self.context, _id, eager=False)
mock_soft_delete.assert_called_once_with(_id)
self.assertEqual(self.context, strategy._context)
@mock.patch.object(db_api.Connection, 'destroy_strategy')
@mock.patch.object(db_api.Connection, 'get_strategy_by_id')
def test_destroy(self, mock_get_strategy, mock_destroy_strategy):
_id = self.fake_strategy['id']
mock_get_strategy.return_value = self.fake_strategy
strategy = objects.Strategy.get_by_id(self.context, _id)
strategy.destroy()
mock_get_strategy.assert_called_once_with(
self.context, _id, eager=False)
mock_destroy_strategy.assert_called_once_with(_id)
self.assertEqual(self.context, strategy._context)

View File

@@ -217,6 +217,14 @@ def get_test_strategy(context, **kw):
strategy = objects.Strategy(context)
for key in db_strategy:
setattr(strategy, key, db_strategy[key])
# ObjectField checks for the object type, so if we want to simulate a
# non-eager object loading, the field should not be referenced at all.
# Contrarily, eager loading need the data to be casted to the object type
# that was specified by the ObjectField.
if kw.get('goal'):
strategy.goal = objects.Goal(context, **kw.get('goal'))
return strategy