#pragma once #include #include #include #include #include #include #include namespace bt { enum class task_result { failure, success }; class node_task { public: class promise_type; using handle_type = std::coroutine_handle; class promise_type { public: std::optional result; handle_type get_return_object() noexcept { return handle_type::from_promise(*this); } std::suspend_always initial_suspend() noexcept { return {}; } std::suspend_always final_suspend() noexcept { return {}; } void return_value(task_result r) noexcept { result = r; } void unhandled_exception() noexcept { std::terminate(); // TODO: handle exception } }; class awaiter { public: awaiter(handle_type handle) noexcept : _handle(handle) {} bool await_ready() const noexcept { return false; } auto await_suspend(std::coroutine_handle<>) noexcept { return _handle; } task_result await_resume() noexcept { return *_handle.promise().result; } private: handle_type _handle; }; node_task(handle_type handle) noexcept : _handle(handle) {} node_task(node_task &&other) noexcept : _handle(other._handle) { other._handle = nullptr; } ~node_task() { if (_handle) { _handle.destroy(); } } node_task &operator=(node_task &&other) noexcept { if (&other != this) { if (_handle) { _handle.destroy(); } _handle = other._handle; other._handle = nullptr; } return *this; } awaiter operator co_await() noexcept { return awaiter{_handle}; } void resume() noexcept { _handle.resume(); } bool done() const noexcept { return _handle.done(); } task_result result() noexcept { return *_handle.promise().result; } std::optional try_result() noexcept { return _handle.promise().result; } private: handle_type _handle; }; using yield = std::suspend_always; template class behavior_node { public: virtual node_task tick(Context &ctx) noexcept = 0; virtual ~behavior_node() = default; }; template class sequence_node : public behavior_node { public: void add_child(std::unique_ptr> &&child) { _children.push_back(std::move(child)); } node_task tick(Context &ctx) noexcept override { std::size_t idx = 0; while (idx < _children.size()) { auto task = _children[idx]->tick(ctx); while (!task.done()) { task.resume(); if (idx < _children.size() - 1) co_await yield{}; } if (task.result() == task_result::failure) { co_return task_result::failure; } ++idx; } co_return task_result::success; } private: std::vector>> _children; }; template auto sequence() { return std::make_unique>(); } template class selector_node : public behavior_node { public: void add_child(std::unique_ptr> &&child) { _children.push_back(std::move(child)); } node_task tick(Context &ctx) noexcept override { for (std::size_t idx = 0; idx < _children.size(); ++idx) { auto task = _children[idx]->tick(ctx); while (!task.done()) { task.resume(); co_await yield{}; } if (task.result() == task_result::success) { co_return task_result::success; } } co_return task_result::failure; } private: std::vector>> _children; }; template auto selector() { return std::make_unique>(); } template class action_node : public behavior_node { public: action_node(TickFn tick) noexcept : _tick(tick) {} node_task tick(Context &ctx) noexcept override { return _tick(ctx); } private: TickFn _tick; }; template auto action(TickFn tick) { return std::make_unique>(tick); } template class conditional_node : public behavior_node { public: conditional_node(ConditionFn condition) noexcept : _condition(condition) {} node_task tick(Context &ctx) noexcept override { if (_condition(ctx)) { co_return task_result::success; } co_return task_result::failure; } private: ConditionFn _condition; }; template auto conditional(ConditionFn condition) { return std::make_unique>(condition); } template class tree_builder { public: tree_builder() noexcept : _root(nullptr) {} tree_builder &begin_sequence() { auto sequence = std::make_unique>(); node_stack.push_back(sequence.get()); if (!_root) { _root = std::move(sequence); } else { node_stack[node_stack.size() - 2]->add_child(std::move(sequence)); } return *this; } tree_builder &end_sequence() { if (!node_stack.empty() && dynamic_cast *>(node_stack.back()) != nullptr) { node_stack.pop_back(); } return *this; } tree_builder &begin_selector() { auto selector = std::make_unique>(); node_stack.push_back(selector.get()); if (!_root) { _root = std::move(selector); } else { node_stack[node_stack.size() - 2]->add_child(std::move(selector)); } return *this; } tree_builder &end_selector() { if (!node_stack.empty() && dynamic_cast *>(node_stack.back()) != nullptr) { node_stack.pop_back(); } return *this; } template tree_builder &action(TickFn tick) { if (node_stack.empty()) { throw std::runtime_error("Can't add an action outside a composite node."); } node_stack.back()->add_child( std::make_unique>(tick)); return *this; } template tree_builder &conditional(ConditionFn condition) { if (node_stack.empty()) { throw std::runtime_error( "Can't add a conditional outside a composite node."); } node_stack.back()->add_child( std::make_unique>(condition)); return *this; } std::unique_ptr> build() { return std::move(_root); } private: std::unique_ptr> _root; std::vector *> node_stack; }; } // namespace bt