diff --git a/include/bt.hpp b/include/bt.hpp index a9ae959..7d7d957 100644 --- a/include/bt.hpp +++ b/include/bt.hpp @@ -2,224 +2,275 @@ #include #include +#include +#include #include #include -#include #include -#include -enum class task_result { - failure, - success -}; +namespace bt { +enum class task_result { failure, success }; -class node_task -{ +class node_task { public: - class promise_type; + class promise_type; - using handle_type = std::coroutine_handle; + using handle_type = std::coroutine_handle; - class promise_type { - public: - std::optional result; - handle_type get_return_object() noexcept { - return handle_type::from_promise(*this); - } - - std::suspend_always initial_suspend() noexcept { return {}; } - - std::suspend_always final_suspend() noexcept { return {}; } - - void return_value(task_result r) noexcept { - result = r; - } - - void unhandled_exception() noexcept { - std::terminate(); // TODO: handle exception - } - }; - - class awaiter { - public: - awaiter(handle_type handle) noexcept : _handle(handle) {} - - bool await_ready() const noexcept { - return false; - } - - auto await_suspend(std::coroutine_handle<>) noexcept { - return _handle; - } - - task_result await_resume() noexcept { - return *_handle.promise().result; - } - - private: - handle_type _handle; - }; - - node_task(handle_type handle) noexcept : _handle(handle) {} - - node_task(node_task&& other) noexcept : _handle(other._handle) { - other._handle = nullptr; + class promise_type { + public: + std::optional result; + handle_type get_return_object() noexcept { + return handle_type::from_promise(*this); } - ~node_task() { - if (_handle) { - _handle.destroy(); - } - } + std::suspend_always initial_suspend() noexcept { return {}; } - node_task& operator=(node_task&& other) noexcept { - if (&other != this) { - if (_handle) { - _handle.destroy(); - } - _handle = other._handle; - other._handle = nullptr; - } - return *this; - } + std::suspend_always final_suspend() noexcept { return {}; } - awaiter operator co_await() noexcept { - return awaiter{_handle}; - } + void return_value(task_result r) noexcept { result = r; } - void resume() noexcept { - _handle.resume(); + void unhandled_exception() noexcept { + std::terminate(); // TODO: handle exception } + }; - bool done() const noexcept { - return _handle.done(); - } + class awaiter { + public: + awaiter(handle_type handle) noexcept : _handle(handle) {} - task_result result() noexcept { - return *_handle.promise().result; - } + bool await_ready() const noexcept { return false; } - std::optional try_result() noexcept { - return _handle.promise().result; + auto await_suspend(std::coroutine_handle<>) noexcept { return _handle; } + + task_result await_resume() noexcept { return *_handle.promise().result; } + + private: + handle_type _handle; + }; + + node_task(handle_type handle) noexcept : _handle(handle) {} + + node_task(node_task &&other) noexcept : _handle(other._handle) { + other._handle = nullptr; + } + + ~node_task() { + if (_handle) { + _handle.destroy(); } + } + + node_task &operator=(node_task &&other) noexcept { + if (&other != this) { + if (_handle) { + _handle.destroy(); + } + _handle = other._handle; + other._handle = nullptr; + } + return *this; + } + + awaiter operator co_await() noexcept { return awaiter{_handle}; } + + void resume() noexcept { _handle.resume(); } + + bool done() const noexcept { return _handle.done(); } + + task_result result() noexcept { return *_handle.promise().result; } + + std::optional try_result() noexcept { + return _handle.promise().result; + } private: - handle_type _handle; + handle_type _handle; }; using yield = std::suspend_always; -template -class behavior_node { +template class behavior_node { public: - virtual node_task tick(Context& ctx) noexcept = 0; + virtual node_task tick(Context &ctx) noexcept = 0; - virtual ~behavior_node() = default; + virtual ~behavior_node() = default; }; -template +template class sequence_node : public behavior_node { public: - void add_child(std::unique_ptr>&& child) { - _children.push_back(std::move(child)); + void add_child(std::unique_ptr> &&child) { + _children.push_back(std::move(child)); + } + + node_task tick(Context &ctx) noexcept override { + std::size_t idx = 0; + while (idx < _children.size()) { + auto task = _children[idx]->tick(ctx); + while (!task.done()) { + task.resume(); + + if (idx < _children.size() - 1) + co_await yield{}; + } + + if (task.result() == task_result::failure) { + co_return task_result::failure; + } + + ++idx; } - node_task tick(Context& ctx) noexcept override { - std::size_t idx = 0; - while (idx < _children.size()) { - auto task = _children[idx]->tick(ctx); - while (!task.done()) { - task.resume(); + co_return task_result::success; + } - if (idx < _children.size() - 1) - co_await yield{}; - } - - if (task.result() == task_result::failure) { - co_return task_result::failure; - } - - ++idx; - } - - co_return task_result::success; - } - private: - std::vector>> _children; + std::vector>> _children; }; -template -auto sequence() { - return std::make_unique>(); +template auto sequence() { + return std::make_unique>(); } -template +template class selector_node : public behavior_node { public: - void add_child(std::unique_ptr>&& child) { - _children.push_back(std::move(child)); - } + void add_child(std::unique_ptr> &&child) { + _children.push_back(std::move(child)); + } - node_task tick(Context& ctx) noexcept override { - for (std::size_t idx = 0; idx < _children.size(); ++idx) { - auto task = _children[idx]->tick(ctx); - while (!task.done()) { - task.resume(); - co_await yield{}; - } + node_task tick(Context &ctx) noexcept override { + for (std::size_t idx = 0; idx < _children.size(); ++idx) { + auto task = _children[idx]->tick(ctx); + while (!task.done()) { + task.resume(); + co_await yield{}; + } - if (task.result() == task_result::success) { - co_return task_result::success; - } - } - co_return task_result::failure; + if (task.result() == task_result::success) { + co_return task_result::success; + } } + co_return task_result::failure; + } private: - std::vector>> _children; + std::vector>> _children; }; -template -auto selector() { - return std::make_unique>(); +template auto selector() { + return std::make_unique>(); } -template +template class action_node : public behavior_node { public: - action_node(TickFn tick) noexcept : _tick(tick) {} + action_node(TickFn tick) noexcept : _tick(tick) {} - node_task tick(Context& ctx) noexcept override { - return _tick(ctx); - } + node_task tick(Context &ctx) noexcept override { return _tick(ctx); } private: - TickFn _tick; + TickFn _tick; }; -template -auto action(TickFn tick) { - return std::make_unique>(tick); +template auto action(TickFn tick) { + return std::make_unique>(tick); } -template +template class conditional_node : public behavior_node { public: - conditional_node(ConditionFn condition) noexcept : _condition(condition) {} + conditional_node(ConditionFn condition) noexcept : _condition(condition) {} - node_task tick(Context& ctx) noexcept override { - if (_condition(ctx)) { - co_return task_result::success; - } - co_return task_result::failure; + node_task tick(Context &ctx) noexcept override { + if (_condition(ctx)) { + co_return task_result::success; } + co_return task_result::failure; + } private: - ConditionFn _condition; + ConditionFn _condition; }; -template +template auto conditional(ConditionFn condition) { - return std::make_unique>(condition); -} \ No newline at end of file + return std::make_unique>(condition); +} + +template class tree_builder { +public: + tree_builder() noexcept : _root(nullptr) {} + + tree_builder &begin_sequence() { + auto sequence = std::make_unique>(); + node_stack.push_back(sequence.get()); + + if (!_root) { + _root = std::move(sequence); + } else { + node_stack[node_stack.size() - 2]->add_child(std::move(sequence)); + } + + return *this; + } + + tree_builder &end_sequence() { + if (!node_stack.empty() && + dynamic_cast *>(node_stack.back()) != nullptr) { + node_stack.pop_back(); + } + + return *this; + } + + tree_builder &begin_selector() { + auto selector = std::make_unique>(); + node_stack.push_back(selector.get()); + + if (!_root) { + _root = std::move(selector); + } else { + node_stack[node_stack.size() - 2]->add_child(std::move(selector)); + } + + return *this; + } + + tree_builder &end_selector() { + if (!node_stack.empty() && + dynamic_cast *>(node_stack.back()) != nullptr) { + node_stack.pop_back(); + } + + return *this; + } + + template tree_builder &action(TickFn tick) { + if (node_stack.empty()) { + throw std::runtime_error("Can't add an action outside a composite node."); + } + node_stack.back()->add_child( + std::make_unique>(tick)); + return *this; + } + + template + tree_builder &conditional(ConditionFn condition) { + if (node_stack.empty()) { + throw std::runtime_error( + "Can't add a conditional outside a composite node."); + } + node_stack.back()->add_child( + std::make_unique>(condition)); + return *this; + } + + std::unique_ptr> build() { return std::move(_root); } + +private: + std::unique_ptr> _root; + std::vector *> node_stack; +}; +} // namespace bt