Fixed tree builder

master
Warren Ulrich 2023-06-12 00:08:04 -07:00
parent 2c6c5a2e74
commit c19dfa7c06
1 changed files with 70 additions and 52 deletions

View File

@ -98,20 +98,27 @@ public:
}; };
template <typename Context> template <typename Context>
class sequence_node : public behavior_node<Context> { class composite_node : public behavior_node<Context> {
public: public:
void add_child(std::unique_ptr<behavior_node<Context>> &&child) { void add_child(std::unique_ptr<behavior_node<Context>> &&child) {
_children.push_back(std::move(child)); _children.push_back(std::move(child));
} }
protected:
std::vector<std::unique_ptr<behavior_node<Context>>> _children;
};
template <typename Context>
class sequence_node : public composite_node<Context> {
public:
node_task tick(Context &ctx) noexcept override { node_task tick(Context &ctx) noexcept override {
std::size_t idx = 0; std::size_t idx = 0;
while (idx < _children.size()) { while (idx < this->_children.size()) {
auto task = _children[idx]->tick(ctx); auto task = this->_children[idx]->tick(ctx);
while (!task.done()) { while (!task.done()) {
task.resume(); task.resume();
if (idx < _children.size() - 1) if (idx < this->_children.size() - 1)
co_await yield{}; co_await yield{};
} }
@ -124,9 +131,6 @@ public:
co_return task_result::success; co_return task_result::success;
} }
private:
std::vector<std::unique_ptr<behavior_node<Context>>> _children;
}; };
template <typename Context> auto sequence() { template <typename Context> auto sequence() {
@ -134,15 +138,11 @@ template <typename Context> auto sequence() {
} }
template <typename Context> template <typename Context>
class selector_node : public behavior_node<Context> { class selector_node : public composite_node<Context> {
public: public:
void add_child(std::unique_ptr<behavior_node<Context>> &&child) {
_children.push_back(std::move(child));
}
node_task tick(Context &ctx) noexcept override { node_task tick(Context &ctx) noexcept override {
for (std::size_t idx = 0; idx < _children.size(); ++idx) { for (std::size_t idx = 0; idx < this->_children.size(); ++idx) {
auto task = _children[idx]->tick(ctx); auto task = this->_children[idx]->tick(ctx);
while (!task.done()) { while (!task.done()) {
task.resume(); task.resume();
co_await yield{}; co_await yield{};
@ -154,9 +154,6 @@ public:
} }
co_return task_result::failure; co_return task_result::failure;
} }
private:
std::vector<std::unique_ptr<behavior_node<Context>>> _children;
}; };
template <typename Context> auto selector() { template <typename Context> auto selector() {
@ -204,73 +201,94 @@ public:
tree_builder() noexcept : _root(nullptr) {} tree_builder() noexcept : _root(nullptr) {}
tree_builder &begin_sequence() { tree_builder &begin_sequence() {
auto sequence = std::make_unique<sequence_node<Context>>(); _node_stack.push_back(std::make_unique<sequence_node<Context>>());
node_stack.push_back(sequence.get()); _children.emplace_back();
if (!_root) {
_root = std::move(sequence);
} else {
node_stack[node_stack.size() - 2]->add_child(std::move(sequence));
}
return *this; return *this;
} }
tree_builder &end_sequence() { tree_builder &end_sequence() {
if (!node_stack.empty() && if (_node_stack.empty()) {
dynamic_cast<sequence_node<Context> *>(node_stack.back()) != nullptr) { throw std::runtime_error("Mismatched begin_sequence/end_sequence");
node_stack.pop_back(); }
auto node = std::move(_node_stack.back());
_node_stack.pop_back();
auto children = std::move(_children.back());
_children.pop_back();
for (auto &&child : children) {
static_cast<composite_node<Context> *>(node.get())->add_child(std::move(child));
}
if (_node_stack.empty()) {
_root = std::move(node);
} else {
_children.back().push_back(std::move(node));
} }
return *this; return *this;
} }
tree_builder &begin_selector() { tree_builder &begin_selector() {
auto selector = std::make_unique<selector_node<Context>>(); _node_stack.push_back(std::make_unique<selector_node<Context>>());
node_stack.push_back(selector.get()); _children.emplace_back();
if (!_root) {
_root = std::move(selector);
} else {
node_stack[node_stack.size() - 2]->add_child(std::move(selector));
}
return *this; return *this;
} }
tree_builder &end_selector() { tree_builder &end_selector() {
if (!node_stack.empty() && if (_node_stack.empty()) {
dynamic_cast<selector_node<Context> *>(node_stack.back()) != nullptr) { throw std::runtime_error("Mismatched begin_selector/end_selector");
node_stack.pop_back(); }
auto node = std::move(_node_stack.back());
_node_stack.pop_back();
auto children = std::move(_children.back());
_children.pop_back();
for (auto &&child : children) {
static_cast<composite_node<Context> *>(node.get())->add_child(std::move(child));
}
if (_node_stack.empty()) {
_root = std::move(node);
} else {
_children.back().push_back(std::move(node));
} }
return *this; return *this;
} }
template <typename TickFn> tree_builder &action(TickFn tick) { template <typename TickFn> tree_builder &action(TickFn tick) {
if (node_stack.empty()) { if (_node_stack.empty()) {
throw std::runtime_error("Can't add an action outside a composite node."); throw std::runtime_error("Action must be within a sequence or selector");
} }
node_stack.back()->add_child(
std::make_unique<action_node<Context, TickFn>>(tick)); _children.back().push_back(std::make_unique<action_node<Context, TickFn>>(tick));
return *this; return *this;
} }
template <typename ConditionFn> template <typename ConditionFn>
tree_builder &conditional(ConditionFn condition) { tree_builder &conditional(ConditionFn condition) {
if (node_stack.empty()) { if (_node_stack.empty()) {
throw std::runtime_error( throw std::runtime_error("Condition must be within a sequence or selector");
"Can't add a conditional outside a composite node.");
} }
node_stack.back()->add_child(
std::make_unique<conditional_node<Context, ConditionFn>>(condition)); _children.back().push_back(std::make_unique<conditional_node<Context, ConditionFn>>(condition));
return *this; return *this;
} }
std::unique_ptr<behavior_node<Context>> build() { return std::move(_root); } std::unique_ptr<behavior_node<Context>> build() {
if (!_node_stack.empty()) {
throw std::runtime_error("Mismatched begin/end");
}
return std::move(_root);
}
private: private:
std::unique_ptr<behavior_node<Context>> _root; std::unique_ptr<behavior_node<Context>> _root;
std::vector<behavior_node<Context> *> node_stack; std::vector<std::unique_ptr<behavior_node<Context>>> _node_stack;
std::vector<std::vector<std::unique_ptr<behavior_node<Context>>>> _children;
}; };
} // namespace bt } // namespace bt