diff --git a/watcher/decision_engine/model/model_root.py b/watcher/decision_engine/model/model_root.py index 1845638de..37b9e2296 100644 --- a/watcher/decision_engine/model/model_root.py +++ b/watcher/decision_engine/model/model_root.py @@ -149,6 +149,19 @@ class ModelRoot(nx.DiGraph, base.Model): except exception.ComputeResourceNotFound: raise exception.ComputeNodeNotFound(name=uuid) + @lockutils.synchronized("model_root") + def get_node_by_name(self, name): + try: + node_list = [cn['attr'] for uuid, cn in self.nodes(data=True) + if (isinstance(cn['attr'], element.ComputeNode) and + cn['attr']['hostname'] == name)] + if node_list: + return node_list[0] + else: + raise exception.ComputeResourceNotFound + except exception.ComputeResourceNotFound: + raise exception.ComputeNodeNotFound(name=name) + @lockutils.synchronized("model_root") def get_instance_by_uuid(self, uuid): try: diff --git a/watcher/tests/decision_engine/model/test_model.py b/watcher/tests/decision_engine/model/test_model.py index 611aff6af..27016ee14 100644 --- a/watcher/tests/decision_engine/model/test_model.py +++ b/watcher/tests/decision_engine/model/test_model.py @@ -122,6 +122,32 @@ class TestModel(base.TestCase): node.state = element.ServiceState.OFFLINE.value self.assertIn(node.state, [el.value for el in element.ServiceState]) + def test_get_node_by_name(self): + model = model_root.ModelRoot() + uuid_ = "{0}".format(uuidutils.generate_uuid()) + name = 'test_node' + node = element.ComputeNode() + node.uuid = uuid_ + node.hostname = name + model.add_node(node) + compute_node = model.get_node_by_name(name) + model.assert_node(compute_node) + self.assertEqual(name, compute_node['hostname']) + self.assertEqual(uuid_, compute_node['uuid']) + + def test_node_from_name_raise(self): + model = model_root.ModelRoot() + uuid_ = "{0}".format(uuidutils.generate_uuid()) + name = 'test_node' + node = element.ComputeNode() + node.uuid = uuid_ + node.hostname = name + model.add_node(node) + + fake_name = 'fake_node' + self.assertRaises(exception.ComputeNodeNotFound, + model.get_node_by_name, fake_name) + def test_node_from_uuid_raise(self): model = model_root.ModelRoot() uuid_ = "{0}".format(uuidutils.generate_uuid())