tree_builder helper class, formatting consistency

master
Warren Ulrich 2023-06-11 23:23:08 -07:00
parent 0ef55377b2
commit 2c6c5a2e74
1 changed files with 205 additions and 154 deletions

View File

@ -2,19 +2,16 @@
#include <coroutine> #include <coroutine>
#include <iostream> #include <iostream>
#include <memory>
#include <optional>
#include <stdexcept> #include <stdexcept>
#include <utility> #include <utility>
#include <optional>
#include <vector> #include <vector>
#include <memory>
enum class task_result { namespace bt {
failure, enum class task_result { failure, success };
success
};
class node_task class node_task {
{
public: public:
class promise_type; class promise_type;
@ -31,9 +28,7 @@ public:
std::suspend_always final_suspend() noexcept { return {}; } std::suspend_always final_suspend() noexcept { return {}; }
void return_value(task_result r) noexcept { void return_value(task_result r) noexcept { result = r; }
result = r;
}
void unhandled_exception() noexcept { void unhandled_exception() noexcept {
std::terminate(); // TODO: handle exception std::terminate(); // TODO: handle exception
@ -44,17 +39,11 @@ public:
public: public:
awaiter(handle_type handle) noexcept : _handle(handle) {} awaiter(handle_type handle) noexcept : _handle(handle) {}
bool await_ready() const noexcept { bool await_ready() const noexcept { return false; }
return false;
}
auto await_suspend(std::coroutine_handle<>) noexcept { auto await_suspend(std::coroutine_handle<>) noexcept { return _handle; }
return _handle;
}
task_result await_resume() noexcept { task_result await_resume() noexcept { return *_handle.promise().result; }
return *_handle.promise().result;
}
private: private:
handle_type _handle; handle_type _handle;
@ -83,21 +72,13 @@ public:
return *this; return *this;
} }
awaiter operator co_await() noexcept { awaiter operator co_await() noexcept { return awaiter{_handle}; }
return awaiter{_handle};
}
void resume() noexcept { void resume() noexcept { _handle.resume(); }
_handle.resume();
}
bool done() const noexcept { bool done() const noexcept { return _handle.done(); }
return _handle.done();
}
task_result result() noexcept { task_result result() noexcept { return *_handle.promise().result; }
return *_handle.promise().result;
}
std::optional<task_result> try_result() noexcept { std::optional<task_result> try_result() noexcept {
return _handle.promise().result; return _handle.promise().result;
@ -109,8 +90,7 @@ private:
using yield = std::suspend_always; using yield = std::suspend_always;
template<typename Context> template <typename Context> class behavior_node {
class behavior_node {
public: public:
virtual node_task tick(Context &ctx) noexcept = 0; virtual node_task tick(Context &ctx) noexcept = 0;
@ -149,8 +129,7 @@ private:
std::vector<std::unique_ptr<behavior_node<Context>>> _children; std::vector<std::unique_ptr<behavior_node<Context>>> _children;
}; };
template<typename Context> template <typename Context> auto sequence() {
auto sequence() {
return std::make_unique<sequence_node<Context>>(); return std::make_unique<sequence_node<Context>>();
} }
@ -180,8 +159,7 @@ private:
std::vector<std::unique_ptr<behavior_node<Context>>> _children; std::vector<std::unique_ptr<behavior_node<Context>>> _children;
}; };
template<typename Context> template <typename Context> auto selector() {
auto selector() {
return std::make_unique<selector_node<Context>>(); return std::make_unique<selector_node<Context>>();
} }
@ -190,16 +168,13 @@ class action_node : public behavior_node<Context> {
public: public:
action_node(TickFn tick) noexcept : _tick(tick) {} action_node(TickFn tick) noexcept : _tick(tick) {}
node_task tick(Context& ctx) noexcept override { node_task tick(Context &ctx) noexcept override { return _tick(ctx); }
return _tick(ctx);
}
private: private:
TickFn _tick; TickFn _tick;
}; };
template<typename Context, typename TickFn> template <typename Context, typename TickFn> auto action(TickFn tick) {
auto action(TickFn tick) {
return std::make_unique<action_node<Context, TickFn>>(tick); return std::make_unique<action_node<Context, TickFn>>(tick);
} }
@ -223,3 +198,79 @@ template<typename Context, typename ConditionFn>
auto conditional(ConditionFn condition) { auto conditional(ConditionFn condition) {
return std::make_unique<conditional_node<Context, ConditionFn>>(condition); return std::make_unique<conditional_node<Context, ConditionFn>>(condition);
} }
template <typename Context> class tree_builder {
public:
tree_builder() noexcept : _root(nullptr) {}
tree_builder &begin_sequence() {
auto sequence = std::make_unique<sequence_node<Context>>();
node_stack.push_back(sequence.get());
if (!_root) {
_root = std::move(sequence);
} else {
node_stack[node_stack.size() - 2]->add_child(std::move(sequence));
}
return *this;
}
tree_builder &end_sequence() {
if (!node_stack.empty() &&
dynamic_cast<sequence_node<Context> *>(node_stack.back()) != nullptr) {
node_stack.pop_back();
}
return *this;
}
tree_builder &begin_selector() {
auto selector = std::make_unique<selector_node<Context>>();
node_stack.push_back(selector.get());
if (!_root) {
_root = std::move(selector);
} else {
node_stack[node_stack.size() - 2]->add_child(std::move(selector));
}
return *this;
}
tree_builder &end_selector() {
if (!node_stack.empty() &&
dynamic_cast<selector_node<Context> *>(node_stack.back()) != nullptr) {
node_stack.pop_back();
}
return *this;
}
template <typename TickFn> tree_builder &action(TickFn tick) {
if (node_stack.empty()) {
throw std::runtime_error("Can't add an action outside a composite node.");
}
node_stack.back()->add_child(
std::make_unique<action_node<Context, TickFn>>(tick));
return *this;
}
template <typename ConditionFn>
tree_builder &conditional(ConditionFn condition) {
if (node_stack.empty()) {
throw std::runtime_error(
"Can't add a conditional outside a composite node.");
}
node_stack.back()->add_child(
std::make_unique<conditional_node<Context, ConditionFn>>(condition));
return *this;
}
std::unique_ptr<behavior_node<Context>> build() { return std::move(_root); }
private:
std::unique_ptr<behavior_node<Context>> _root;
std::vector<behavior_node<Context> *> node_stack;
};
} // namespace bt