Merge "Added 'goal' ObjectField for Strategy object"

This commit is contained in:
Jenkins
2016-11-10 09:38:04 +00:00
committed by Gerrit Code Review
9 changed files with 255 additions and 132 deletions

View File

@@ -103,6 +103,19 @@ class WatcherPersistentObject(object):
'deleted_at': ovo_fields.DateTimeField(nullable=True), 'deleted_at': ovo_fields.DateTimeField(nullable=True),
} }
# Mapping between the object field name and a 2-tuple pair composed of
# its object type (e.g. objects.RelatedObject) and the the name of the
# model field related ID (or UUID) foreign key field.
# e.g.:
#
# fields = {
# # [...]
# 'related_object_id': fields.IntegerField(), # Foreign key
# 'related_object': wfields.ObjectField('RelatedObject'),
# }
# {'related_object': (objects.RelatedObject, 'related_object_id')}
object_fields = {}
def obj_refresh(self, loaded_object): def obj_refresh(self, loaded_object):
"""Applies updates for objects that inherit from base.WatcherObject. """Applies updates for objects that inherit from base.WatcherObject.
@@ -116,17 +129,39 @@ class WatcherPersistentObject(object):
self[field] = loaded_object[field] self[field] = loaded_object[field]
@staticmethod @staticmethod
def _from_db_object(obj, db_object): def _from_db_object(obj, db_object, eager=False):
"""Converts a database entity to a formal object. """Converts a database entity to a formal object.
:param obj: An object of the class. :param obj: An object of the class.
:param db_object: A DB model of the object :param db_object: A DB model of the object
:param eager: Enable the loading of object fields (Default: False)
:return: The object of the class with the database entity added :return: The object of the class with the database entity added
""" """
obj_class = type(obj)
object_fields = obj_class.object_fields
for field in obj.fields: for field in obj.fields:
obj[field] = db_object[field] if field not in object_fields:
obj[field] = db_object[field]
if eager:
# Load object fields
context = obj._context
loadable_fields = (
(obj_field, related_obj_cls, rel_id)
for obj_field, (related_obj_cls, rel_id)
in object_fields.items()
if obj[rel_id]
)
for obj_field, related_obj_cls, rel_id in loadable_fields:
if db_object.get(obj_field) and obj[rel_id]:
# The object field data was eagerly loaded alongside
# the main object data
obj[obj_field] = related_obj_cls._from_db_object(
related_obj_cls(context), db_object[obj_field])
else:
# The object field data wasn't loaded yet
obj[obj_field] = related_obj_cls.get(context, obj[rel_id])
obj.obj_reset_changes() obj.obj_reset_changes()
return obj return obj

View File

@@ -24,12 +24,13 @@ from oslo_versionedobjects import fields
LOG = log.getLogger(__name__) LOG = log.getLogger(__name__)
IntegerField = fields.IntegerField
UUIDField = fields.UUIDField
StringField = fields.StringField
DateTimeField = fields.DateTimeField
BooleanField = fields.BooleanField BooleanField = fields.BooleanField
DateTimeField = fields.DateTimeField
IntegerField = fields.IntegerField
ListOfStringsField = fields.ListOfStringsField ListOfStringsField = fields.ListOfStringsField
ObjectField = fields.ObjectField
StringField = fields.StringField
UUIDField = fields.UUIDField
class Numeric(fields.FieldType): class Numeric(fields.FieldType):

View File

@@ -17,6 +17,7 @@
from watcher.common import exception from watcher.common import exception
from watcher.common import utils from watcher.common import utils
from watcher.db import api as db_api from watcher.db import api as db_api
from watcher import objects
from watcher.objects import base from watcher.objects import base
from watcher.objects import fields as wfields from watcher.objects import fields as wfields
@@ -25,6 +26,10 @@ from watcher.objects import fields as wfields
class Strategy(base.WatcherPersistentObject, base.WatcherObject, class Strategy(base.WatcherPersistentObject, base.WatcherObject,
base.WatcherObjectDictCompat): base.WatcherObjectDictCompat):
# Version 1.0: Initial version
# Version 1.1: Added Goal object field
VERSION = '1.1'
dbapi = db_api.get_instance() dbapi = db_api.get_instance()
fields = { fields = {
@@ -34,10 +39,13 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
'display_name': wfields.StringField(), 'display_name': wfields.StringField(),
'goal_id': wfields.IntegerField(), 'goal_id': wfields.IntegerField(),
'parameters_spec': wfields.FlexibleDictField(nullable=True), 'parameters_spec': wfields.FlexibleDictField(nullable=True),
'goal': wfields.ObjectField('Goal', nullable=True),
} }
object_fields = {'goal': (objects.Goal, 'goal_id')}
@base.remotable_classmethod @base.remotable_classmethod
def get(cls, context, strategy_id): def get(cls, context, strategy_id, eager=False):
"""Find a strategy based on its id or uuid """Find a strategy based on its id or uuid
:param context: Security context. NOTE: This should only :param context: Security context. NOTE: This should only
@@ -47,17 +55,18 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
A context should be set when instantiating the A context should be set when instantiating the
object, e.g.: Strategy(context) object, e.g.: Strategy(context)
:param strategy_id: the id *or* uuid of a strategy. :param strategy_id: the id *or* uuid of a strategy.
:returns: a :class:`Strategy` object. :param eager: Load object fields if True (Default: False)
:returns: A :class:`Strategy` object.
""" """
if utils.is_int_like(strategy_id): if utils.is_int_like(strategy_id):
return cls.get_by_id(context, strategy_id) return cls.get_by_id(context, strategy_id, eager=eager)
elif utils.is_uuid_like(strategy_id): elif utils.is_uuid_like(strategy_id):
return cls.get_by_uuid(context, strategy_id) return cls.get_by_uuid(context, strategy_id, eager=eager)
else: else:
raise exception.InvalidIdentity(identity=strategy_id) raise exception.InvalidIdentity(identity=strategy_id)
@base.remotable_classmethod @base.remotable_classmethod
def get_by_id(cls, context, strategy_id): def get_by_id(cls, context, strategy_id, eager=False):
"""Find a strategy based on its integer id """Find a strategy based on its integer id
:param context: Security context. NOTE: This should only :param context: Security context. NOTE: This should only
@@ -67,14 +76,16 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
A context should be set when instantiating the A context should be set when instantiating the
object, e.g.: Strategy(context) object, e.g.: Strategy(context)
:param strategy_id: the id of a strategy. :param strategy_id: the id of a strategy.
:returns: a :class:`Strategy` object. :param eager: Load object fields if True (Default: False)
:returns: A :class:`Strategy` object.
""" """
db_strategy = cls.dbapi.get_strategy_by_id(context, strategy_id) db_strategy = cls.dbapi.get_strategy_by_id(
strategy = Strategy._from_db_object(cls(context), db_strategy) context, strategy_id, eager=eager)
strategy = cls._from_db_object(cls(context), db_strategy, eager=eager)
return strategy return strategy
@base.remotable_classmethod @base.remotable_classmethod
def get_by_uuid(cls, context, uuid): def get_by_uuid(cls, context, uuid, eager=False):
"""Find a strategy based on uuid """Find a strategy based on uuid
:param context: Security context. NOTE: This should only :param context: Security context. NOTE: This should only
@@ -84,29 +95,33 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
A context should be set when instantiating the A context should be set when instantiating the
object, e.g.: Strategy(context) object, e.g.: Strategy(context)
:param uuid: the uuid of a strategy. :param uuid: the uuid of a strategy.
:returns: a :class:`Strategy` object. :param eager: Load object fields if True (Default: False)
:returns: A :class:`Strategy` object.
""" """
db_strategy = cls.dbapi.get_strategy_by_uuid(context, uuid) db_strategy = cls.dbapi.get_strategy_by_uuid(
strategy = cls._from_db_object(cls(context), db_strategy) context, uuid, eager=eager)
strategy = cls._from_db_object(cls(context), db_strategy, eager=eager)
return strategy return strategy
@base.remotable_classmethod @base.remotable_classmethod
def get_by_name(cls, context, name): def get_by_name(cls, context, name, eager=False):
"""Find a strategy based on name """Find a strategy based on name
:param name: the name of a strategy.
:param context: Security context :param context: Security context
:returns: a :class:`Strategy` object. :param name: the name of a strategy.
:param eager: Load object fields if True (Default: False)
:returns: A :class:`Strategy` object.
""" """
db_strategy = cls.dbapi.get_strategy_by_name(context, name) db_strategy = cls.dbapi.get_strategy_by_name(
strategy = cls._from_db_object(cls(context), db_strategy) context, name, eager=eager)
strategy = cls._from_db_object(cls(context), db_strategy, eager=eager)
return strategy return strategy
@base.remotable_classmethod @base.remotable_classmethod
def list(cls, context, limit=None, marker=None, filters=None, def list(cls, context, limit=None, marker=None, filters=None,
sort_key=None, sort_dir=None): sort_key=None, sort_dir=None, eager=False):
"""Return a list of :class:`Strategy` objects. """Return a list of :class:`Strategy` objects.
:param context: Security context. NOTE: This should only :param context: Security context. NOTE: This should only
@@ -115,11 +130,12 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
argument, even though we don't use it. argument, even though we don't use it.
A context should be set when instantiating the A context should be set when instantiating the
object, e.g.: Strategy(context) object, e.g.: Strategy(context)
:param filters: dict mapping the filter key to a value.
:param limit: maximum number of resources to return in a single result. :param limit: maximum number of resources to return in a single result.
:param marker: pagination marker for large data sets. :param marker: pagination marker for large data sets.
:param filters: dict mapping the filter key to a value.
:param sort_key: column to sort results by. :param sort_key: column to sort results by.
:param sort_dir: direction to sort. "asc" or "desc". :param sort_dir: direction to sort. "asc" or "desc`".
:param eager: Load object fields if True (Default: False)
:returns: a list of :class:`Strategy` object. :returns: a list of :class:`Strategy` object.
""" """
db_strategies = cls.dbapi.get_strategy_list( db_strategies = cls.dbapi.get_strategy_list(
@@ -130,7 +146,7 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
sort_key=sort_key, sort_key=sort_key,
sort_dir=sort_dir) sort_dir=sort_dir)
return [cls._from_db_object(cls(context), obj) return [cls._from_db_object(cls(context), obj, eager=eager)
for obj in db_strategies] for obj in db_strategies]
@base.remotable @base.remotable
@@ -143,11 +159,14 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
argument, even though we don't use it. argument, even though we don't use it.
A context should be set when instantiating the A context should be set when instantiating the
object, e.g.: Strategy(context) object, e.g.: Strategy(context)
:returns: A :class:`Strategy` object.
""" """
values = self.obj_get_changes() values = self.obj_get_changes()
db_strategy = self.dbapi.create_strategy(values) db_strategy = self.dbapi.create_strategy(values)
self._from_db_object(self, db_strategy) # Note(v-francoise): Always load eagerly upon creation so we can send
# notifications containing information about the related relationships
self._from_db_object(self, db_strategy, eager=True)
def destroy(self, context=None): def destroy(self, context=None):
"""Delete the :class:`Strategy` from the DB. """Delete the :class:`Strategy` from the DB.
@@ -182,7 +201,7 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
self.obj_reset_changes() self.obj_reset_changes()
@base.remotable @base.remotable
def refresh(self, context=None): def refresh(self, context=None, eager=False):
"""Loads updates for this :class:`Strategy`. """Loads updates for this :class:`Strategy`.
Loads a strategy with the same uuid from the database and Loads a strategy with the same uuid from the database and
@@ -195,8 +214,10 @@ class Strategy(base.WatcherPersistentObject, base.WatcherObject,
argument, even though we don't use it. argument, even though we don't use it.
A context should be set when instantiating the A context should be set when instantiating the
object, e.g.: Strategy(context) object, e.g.: Strategy(context)
:param eager: Load object fields if True (Default: False)
""" """
current = self.__class__.get_by_id(self._context, strategy_id=self.id) current = self.__class__.get_by_id(
self._context, strategy_id=self.id, eager=eager)
for field in self.fields: for field in self.fields:
if (hasattr(self, base.get_attrname(field)) and if (hasattr(self, base.get_attrname(field)) and
self[field] != current[field]): self[field] != current[field]):

View File

@@ -21,6 +21,11 @@ from watcher.tests.objects import utils as obj_utils
class TestListStrategy(api_base.FunctionalTest): class TestListStrategy(api_base.FunctionalTest):
def setUp(self):
super(TestListStrategy, self).setUp()
self.fake_goal = obj_utils.create_test_goal(
self.context, uuid=utils.generate_uuid())
def _assert_strategy_fields(self, strategy): def _assert_strategy_fields(self, strategy):
strategy_fields = ['uuid', 'name', 'display_name', 'goal_uuid'] strategy_fields = ['uuid', 'name', 'display_name', 'goal_uuid']
for field in strategy_fields: for field in strategy_fields:
@@ -61,7 +66,6 @@ class TestListStrategy(api_base.FunctionalTest):
self.assertEqual(404, response.status_int) self.assertEqual(404, response.status_int)
def test_detail(self): def test_detail(self):
obj_utils.create_test_goal(self.context)
strategy = obj_utils.create_test_strategy(self.context) strategy = obj_utils.create_test_strategy(self.context)
response = self.get_json('/strategies/detail') response = self.get_json('/strategies/detail')
self.assertEqual(strategy.uuid, response['strategies'][0]["uuid"]) self.assertEqual(strategy.uuid, response['strategies'][0]["uuid"])
@@ -78,7 +82,6 @@ class TestListStrategy(api_base.FunctionalTest):
self.assertEqual(404, response.status_int) self.assertEqual(404, response.status_int)
def test_many(self): def test_many(self):
obj_utils.create_test_goal(self.context)
strategy_list = [] strategy_list = []
for idx in range(1, 6): for idx in range(1, 6):
strategy = obj_utils.create_test_strategy( strategy = obj_utils.create_test_strategy(
@@ -132,12 +135,12 @@ class TestListStrategy(api_base.FunctionalTest):
def test_filter_by_goal_uuid(self): def test_filter_by_goal_uuid(self):
goal1 = obj_utils.create_test_goal( goal1 = obj_utils.create_test_goal(
self.context, self.context,
id=1, id=2,
uuid=utils.generate_uuid(), uuid=utils.generate_uuid(),
name='My_Goal 1') name='My_Goal 1')
goal2 = obj_utils.create_test_goal( goal2 = obj_utils.create_test_goal(
self.context, self.context,
id=2, id=3,
uuid=utils.generate_uuid(), uuid=utils.generate_uuid(),
name='My Goal 2') name='My Goal 2')
@@ -164,12 +167,12 @@ class TestListStrategy(api_base.FunctionalTest):
def test_filter_by_goal_name(self): def test_filter_by_goal_name(self):
goal1 = obj_utils.create_test_goal( goal1 = obj_utils.create_test_goal(
self.context, self.context,
id=1, id=2,
uuid=utils.generate_uuid(), uuid=utils.generate_uuid(),
name='My_Goal 1') name='My_Goal 1')
goal2 = obj_utils.create_test_goal( goal2 = obj_utils.create_test_goal(
self.context, self.context,
id=2, id=3,
uuid=utils.generate_uuid(), uuid=utils.generate_uuid(),
name='My Goal 2') name='My Goal 2')
@@ -196,6 +199,11 @@ class TestListStrategy(api_base.FunctionalTest):
class TestStrategyPolicyEnforcement(api_base.FunctionalTest): class TestStrategyPolicyEnforcement(api_base.FunctionalTest):
def setUp(self):
super(TestStrategyPolicyEnforcement, self).setUp()
self.fake_goal = obj_utils.create_test_goal(
self.context, uuid=utils.generate_uuid())
def _common_policy_check(self, rule, func, *arg, **kwarg): def _common_policy_check(self, rule, func, *arg, **kwarg):
self.policy.set_rules({ self.policy.set_rules({
"admin_api": "(role:admin or role:administrator)", "admin_api": "(role:admin or role:administrator)",
@@ -227,8 +235,8 @@ class TestStrategyPolicyEnforcement(api_base.FunctionalTest):
expect_errors=True) expect_errors=True)
class TestStrategyEnforcementWithAdminContext(TestListStrategy, class TestStrategyEnforcementWithAdminContext(
api_base.AdminRoleTest): TestListStrategy, api_base.AdminRoleTest):
def setUp(self): def setUp(self):
super(TestStrategyEnforcementWithAdminContext, self).setUp() super(TestStrategyEnforcementWithAdminContext, self).setUp()

View File

@@ -199,7 +199,7 @@ def create_test_scoring_engine(**kwargs):
def get_test_strategy(**kwargs): def get_test_strategy(**kwargs):
return { strategy_data = {
'id': kwargs.get('id', 1), 'id': kwargs.get('id', 1),
'uuid': kwargs.get('uuid', 'cb3d0b58-4415-4d90-b75b-1e96878730e3'), 'uuid': kwargs.get('uuid', 'cb3d0b58-4415-4d90-b75b-1e96878730e3'),
'name': kwargs.get('name', 'TEST'), 'name': kwargs.get('name', 'TEST'),
@@ -208,9 +208,16 @@ def get_test_strategy(**kwargs):
'created_at': kwargs.get('created_at'), 'created_at': kwargs.get('created_at'),
'updated_at': kwargs.get('updated_at'), 'updated_at': kwargs.get('updated_at'),
'deleted_at': kwargs.get('deleted_at'), 'deleted_at': kwargs.get('deleted_at'),
'parameters_spec': kwargs.get('parameters_spec', {}) 'parameters_spec': kwargs.get('parameters_spec', {}),
} }
# goal ObjectField doesn't allow None nor dict, so if we want to simulate a
# non-eager object loading, the field should not be referenced at all.
if kwargs.get('goal'):
strategy_data['goal'] = kwargs.get('goal')
return strategy_data
def get_test_service(**kwargs): def get_test_service(**kwargs):
return { return {

View File

@@ -159,8 +159,11 @@ class TestDefaultPlanner(base.DbTestCase):
'migrate': 3 'migrate': 3
} }
obj_utils.create_test_audit_template(self.context) self.goal = obj_utils.create_test_goal(self.context)
self.strategy = obj_utils.create_test_strategy(self.context) self.strategy = obj_utils.create_test_strategy(
self.context, goal_id=self.goal.id)
obj_utils.create_test_audit_template(
self.context, goal_id=self.goal.id, strategy_id=self.strategy.id)
p = mock.patch.object(db_api.BaseConnection, 'create_action_plan') p = mock.patch.object(db_api.BaseConnection, 'create_action_plan')
self.mock_create_action_plan = p.start() self.mock_create_action_plan = p.start()
@@ -185,7 +188,8 @@ class TestDefaultPlanner(base.DbTestCase):
@mock.patch.object(objects.Strategy, 'get_by_name') @mock.patch.object(objects.Strategy, 'get_by_name')
def test_schedule_scheduled_empty(self, m_get_by_name): def test_schedule_scheduled_empty(self, m_get_by_name):
m_get_by_name.return_value = self.strategy m_get_by_name.return_value = self.strategy
audit = db_utils.create_test_audit(uuid=utils.generate_uuid()) audit = db_utils.create_test_audit(
goal_id=self.goal.id, strategy_id=self.strategy.id)
fake_solution = SolutionFakerSingleHyp.build() fake_solution = SolutionFakerSingleHyp.build()
action_plan = self.default_planner.schedule(self.context, action_plan = self.default_planner.schedule(self.context,
audit.id, fake_solution) audit.id, fake_solution)
@@ -194,8 +198,9 @@ class TestDefaultPlanner(base.DbTestCase):
@mock.patch.object(objects.Strategy, 'get_by_name') @mock.patch.object(objects.Strategy, 'get_by_name')
def test_scheduler_warning_empty_action_plan(self, m_get_by_name): def test_scheduler_warning_empty_action_plan(self, m_get_by_name):
m_get_by_name.return_value = self.strategy m_get_by_name.return_value = self.strategy
audit = db_utils.create_test_audit(uuid=utils.generate_uuid()) audit = db_utils.create_test_audit(
goal_id=self.goal.id, strategy_id=self.strategy.id)
fake_solution = SolutionFaker.build() fake_solution = SolutionFaker.build()
action_plan = self.default_planner.schedule(self.context, action_plan = self.default_planner.schedule(
audit.id, fake_solution) self.context, audit.id, fake_solution)
self.assertIsNotNone(action_plan.uuid) self.assertIsNotNone(action_plan.uuid)

View File

@@ -410,7 +410,7 @@ class TestObject(_LocalTest, _TestObject):
# The fingerprint values should only be changed if there is a version bump. # The fingerprint values should only be changed if there is a version bump.
expected_object_fingerprints = { expected_object_fingerprints = {
'Goal': '1.0-93881622db05e7b67a65ca885b4a022e', 'Goal': '1.0-93881622db05e7b67a65ca885b4a022e',
'Strategy': '1.0-e60f62cc854c6e63fb1c3befbfc8629e', 'Strategy': '1.1-73f164491bdd4c034f48083a51bdeb7b',
'AuditTemplate': '1.0-7432ee4d3ce0c7cbb9d11a4565ee8eb6', 'AuditTemplate': '1.0-7432ee4d3ce0c7cbb9d11a4565ee8eb6',
'Audit': '1.0-ebfc5360d019baf583a10a8a27071c97', 'Audit': '1.0-ebfc5360d019baf583a10a8a27071c97',
'ActionPlan': '1.0-cc76fd7f0e8479aeff817dd266341de4', 'ActionPlan': '1.0-cc76fd7f0e8479aeff817dd266341de4',

View File

@@ -15,7 +15,9 @@
import mock import mock
from watcher.common import exception from watcher.common import exception
from watcher.db.sqlalchemy import api as db_api
from watcher import objects from watcher import objects
from watcher.tests.db import base from watcher.tests.db import base
from watcher.tests.db import utils from watcher.tests.db import utils
@@ -23,107 +25,143 @@ from watcher.tests.db import utils
class TestStrategyObject(base.DbTestCase): class TestStrategyObject(base.DbTestCase):
goal_id = 2
scenarios = [
('non_eager', dict(
eager=False, fake_strategy=utils.get_test_strategy(
goal_id=goal_id))),
('eager_with_non_eager_load', dict(
eager=True, fake_strategy=utils.get_test_strategy(
goal_id=goal_id))),
('eager_with_eager_load', dict(
eager=True, fake_strategy=utils.get_test_strategy(
goal_id=goal_id, goal=utils.get_test_goal(id=goal_id)))),
]
def setUp(self): def setUp(self):
super(TestStrategyObject, self).setUp() super(TestStrategyObject, self).setUp()
self.fake_strategy = utils.get_test_strategy() self.fake_goal = utils.create_test_goal(id=self.goal_id)
def test_get_by_id(self): def eager_load_strategy_assert(self, strategy):
if self.eager:
self.assertIsNotNone(strategy.goal)
fields_to_check = set(
super(objects.Goal, objects.Goal).fields
).symmetric_difference(objects.Goal.fields)
db_data = {
k: v for k, v in self.fake_goal.as_dict().items()
if k in fields_to_check}
object_data = {
k: v for k, v in strategy.goal.as_dict().items()
if k in fields_to_check}
self.assertEqual(db_data, object_data)
@mock.patch.object(db_api.Connection, 'get_strategy_by_id')
def test_get_by_id(self, mock_get_strategy):
strategy_id = self.fake_strategy['id'] strategy_id = self.fake_strategy['id']
with mock.patch.object(self.dbapi, 'get_strategy_by_id', mock_get_strategy.return_value = self.fake_strategy
autospec=True) as mock_get_strategy: strategy = objects.Strategy.get(
mock_get_strategy.return_value = self.fake_strategy self.context, strategy_id, eager=self.eager)
strategy = objects.Strategy.get(self.context, strategy_id) mock_get_strategy.assert_called_once_with(
mock_get_strategy.assert_called_once_with(self.context, self.context, strategy_id, eager=self.eager)
strategy_id) self.assertEqual(self.context, strategy._context)
self.assertEqual(self.context, strategy._context) self.eager_load_strategy_assert(strategy)
def test_get_by_uuid(self): @mock.patch.object(db_api.Connection, 'get_strategy_by_uuid')
def test_get_by_uuid(self, mock_get_strategy):
uuid = self.fake_strategy['uuid'] uuid = self.fake_strategy['uuid']
with mock.patch.object(self.dbapi, 'get_strategy_by_uuid', mock_get_strategy.return_value = self.fake_strategy
autospec=True) as mock_get_strategy: strategy = objects.Strategy.get(self.context, uuid, eager=self.eager)
mock_get_strategy.return_value = self.fake_strategy mock_get_strategy.assert_called_once_with(
strategy = objects.Strategy.get(self.context, uuid) self.context, uuid, eager=self.eager)
mock_get_strategy.assert_called_once_with(self.context, uuid) self.assertEqual(self.context, strategy._context)
self.assertEqual(self.context, strategy._context) self.eager_load_strategy_assert(strategy)
def test_get_bad_uuid(self): def test_get_bad_uuid(self):
self.assertRaises(exception.InvalidIdentity, self.assertRaises(exception.InvalidIdentity,
objects.Strategy.get, self.context, 'not-a-uuid') objects.Strategy.get, self.context, 'not-a-uuid')
def test_list(self): @mock.patch.object(db_api.Connection, 'get_strategy_list')
with mock.patch.object(self.dbapi, 'get_strategy_list', def test_list(self, mock_get_list):
autospec=True) as mock_get_list: mock_get_list.return_value = [self.fake_strategy]
mock_get_list.return_value = [self.fake_strategy] strategies = objects.Strategy.list(self.context, eager=self.eager)
strategies = objects.Strategy.list(self.context) self.assertEqual(1, mock_get_list.call_count, 1)
self.assertEqual(1, mock_get_list.call_count, 1) self.assertEqual(1, len(strategies))
self.assertEqual(1, len(strategies)) self.assertIsInstance(strategies[0], objects.Strategy)
self.assertIsInstance(strategies[0], objects.Strategy) self.assertEqual(self.context, strategies[0]._context)
self.assertEqual(self.context, strategies[0]._context) for strategy in strategies:
self.eager_load_strategy_assert(strategy)
def test_create(self): @mock.patch.object(db_api.Connection, 'update_strategy')
with mock.patch.object(self.dbapi, 'create_strategy', @mock.patch.object(db_api.Connection, 'get_strategy_by_id')
autospec=True) as mock_create_strategy: def test_save(self, mock_get_strategy, mock_update_strategy):
mock_create_strategy.return_value = self.fake_strategy
strategy = objects.Strategy(self.context, **self.fake_strategy)
strategy.create()
mock_create_strategy.assert_called_once_with(self.fake_strategy)
self.assertEqual(self.context, strategy._context)
def test_destroy(self):
_id = self.fake_strategy['id'] _id = self.fake_strategy['id']
with mock.patch.object(self.dbapi, 'get_strategy_by_id', mock_get_strategy.return_value = self.fake_strategy
autospec=True) as mock_get_strategy: strategy = objects.Strategy.get_by_id(
mock_get_strategy.return_value = self.fake_strategy self.context, _id, eager=self.eager)
with mock.patch.object(self.dbapi, 'destroy_strategy', strategy.name = 'UPDATED NAME'
autospec=True) as mock_destroy_strategy: strategy.save()
strategy = objects.Strategy.get_by_id(self.context, _id)
strategy.destroy()
mock_get_strategy.assert_called_once_with(self.context, _id)
mock_destroy_strategy.assert_called_once_with(_id)
self.assertEqual(self.context, strategy._context)
def test_save(self): mock_get_strategy.assert_called_once_with(
_id = self.fake_strategy['id'] self.context, _id, eager=self.eager)
with mock.patch.object(self.dbapi, 'get_strategy_by_id', mock_update_strategy.assert_called_once_with(
autospec=True) as mock_get_strategy: _id, {'name': 'UPDATED NAME'})
mock_get_strategy.return_value = self.fake_strategy self.assertEqual(self.context, strategy._context)
with mock.patch.object(self.dbapi, 'update_strategy', self.eager_load_strategy_assert(strategy)
autospec=True) as mock_update_strategy:
strategy = objects.Strategy.get_by_id(self.context, _id)
strategy.name = 'UPDATED NAME'
strategy.save()
mock_get_strategy.assert_called_once_with(self.context, _id) @mock.patch.object(db_api.Connection, 'get_strategy_by_id')
mock_update_strategy.assert_called_once_with( def test_refresh(self, mock_get_strategy):
_id, {'name': 'UPDATED NAME'})
self.assertEqual(self.context, strategy._context)
def test_refresh(self):
_id = self.fake_strategy['id'] _id = self.fake_strategy['id']
returns = [dict(self.fake_strategy, name="first name"), returns = [dict(self.fake_strategy, name="first name"),
dict(self.fake_strategy, name="second name")] dict(self.fake_strategy, name="second name")]
expected = [mock.call(self.context, _id), mock_get_strategy.side_effect = returns
mock.call(self.context, _id)] expected = [mock.call(self.context, _id, eager=self.eager),
with mock.patch.object(self.dbapi, 'get_strategy_by_id', mock.call(self.context, _id, eager=self.eager)]
side_effect=returns, strategy = objects.Strategy.get(self.context, _id, eager=self.eager)
autospec=True) as mock_get_strategy: self.assertEqual("first name", strategy.name)
strategy = objects.Strategy.get(self.context, _id) strategy.refresh(eager=self.eager)
self.assertEqual("first name", strategy.name) self.assertEqual("second name", strategy.name)
strategy.refresh() self.assertEqual(expected, mock_get_strategy.call_args_list)
self.assertEqual("second name", strategy.name) self.assertEqual(self.context, strategy._context)
self.assertEqual(expected, mock_get_strategy.call_args_list) self.eager_load_strategy_assert(strategy)
self.assertEqual(self.context, strategy._context)
def test_soft_delete(self):
class TestCreateDeleteStrategyObject(base.DbTestCase):
def setUp(self):
super(TestCreateDeleteStrategyObject, self).setUp()
self.fake_goal = utils.create_test_goal()
self.fake_strategy = utils.get_test_strategy(goal_id=self.fake_goal.id)
@mock.patch.object(db_api.Connection, 'create_strategy')
def test_create(self, mock_create_strategy):
mock_create_strategy.return_value = self.fake_strategy
strategy = objects.Strategy(self.context, **self.fake_strategy)
strategy.create()
mock_create_strategy.assert_called_once_with(self.fake_strategy)
self.assertEqual(self.context, strategy._context)
@mock.patch.object(db_api.Connection, 'soft_delete_strategy')
@mock.patch.object(db_api.Connection, 'get_strategy_by_id')
def test_soft_delete(self, mock_get_strategy, mock_soft_delete):
_id = self.fake_strategy['id'] _id = self.fake_strategy['id']
with mock.patch.object(self.dbapi, 'get_strategy_by_id', mock_get_strategy.return_value = self.fake_strategy
autospec=True) as mock_get_strategy: strategy = objects.Strategy.get_by_id(self.context, _id)
mock_get_strategy.return_value = self.fake_strategy strategy.soft_delete()
with mock.patch.object(self.dbapi, 'soft_delete_strategy', mock_get_strategy.assert_called_once_with(
autospec=True) as mock_soft_delete: self.context, _id, eager=False)
strategy = objects.Strategy.get_by_id(self.context, _id) mock_soft_delete.assert_called_once_with(_id)
strategy.soft_delete() self.assertEqual(self.context, strategy._context)
mock_get_strategy.assert_called_once_with(self.context, _id)
mock_soft_delete.assert_called_once_with(_id) @mock.patch.object(db_api.Connection, 'destroy_strategy')
self.assertEqual(self.context, strategy._context) @mock.patch.object(db_api.Connection, 'get_strategy_by_id')
def test_destroy(self, mock_get_strategy, mock_destroy_strategy):
_id = self.fake_strategy['id']
mock_get_strategy.return_value = self.fake_strategy
strategy = objects.Strategy.get_by_id(self.context, _id)
strategy.destroy()
mock_get_strategy.assert_called_once_with(
self.context, _id, eager=False)
mock_destroy_strategy.assert_called_once_with(_id)
self.assertEqual(self.context, strategy._context)

View File

@@ -217,6 +217,14 @@ def get_test_strategy(context, **kw):
strategy = objects.Strategy(context) strategy = objects.Strategy(context)
for key in db_strategy: for key in db_strategy:
setattr(strategy, key, db_strategy[key]) setattr(strategy, key, db_strategy[key])
# ObjectField checks for the object type, so if we want to simulate a
# non-eager object loading, the field should not be referenced at all.
# Contrarily, eager loading need the data to be casted to the object type
# that was specified by the ObjectField.
if kw.get('goal'):
strategy.goal = objects.Goal(context, **kw.get('goal'))
return strategy return strategy