// bt/include/bt.hpp
#pragma once
#include <coroutine>
#include <iostream>
#include <memory>
#include <optional>
#include <stdexcept>
#include <utility>
#include <vector>
namespace bt {
/// Final outcome reported by a finished behavior-tree node.
///
/// `failure` is declared first, so its underlying value is 0 and `success`
/// has underlying value 1.
enum class task_result {
  failure, ///< The node did not achieve its goal.
  success, ///< The node achieved its goal.
};
class node_task {
2023-06-12 05:56:19 +00:00
public:
class promise_type;
2023-06-12 05:56:19 +00:00
using handle_type = std::coroutine_handle<promise_type>;
2023-06-12 05:56:19 +00:00
class promise_type {
public:
std::optional<task_result> result;
handle_type get_return_object() noexcept {
return handle_type::from_promise(*this);
}
2023-06-12 05:56:19 +00:00
std::suspend_always initial_suspend() noexcept { return {}; }
2023-06-12 05:56:19 +00:00
std::suspend_always final_suspend() noexcept { return {}; }
2023-06-12 05:56:19 +00:00
void return_value(task_result r) noexcept { result = r; }
2023-06-12 05:56:19 +00:00
void unhandled_exception() noexcept {
std::terminate(); // TODO: handle exception
}
};
2023-06-12 05:56:19 +00:00
class awaiter {
public:
awaiter(handle_type handle) noexcept : _handle(handle) {}
2023-06-12 05:56:19 +00:00
bool await_ready() const noexcept { return false; }
2023-06-12 05:56:19 +00:00
auto await_suspend(std::coroutine_handle<>) noexcept { return _handle; }
2023-06-12 05:56:19 +00:00
task_result await_resume() noexcept { return *_handle.promise().result; }
2023-06-12 05:56:19 +00:00
private:
handle_type _handle;
};
2023-06-12 05:56:19 +00:00
node_task(handle_type handle) noexcept : _handle(handle) {}
2023-06-12 05:56:19 +00:00
node_task(node_task &&other) noexcept : _handle(other._handle) {
other._handle = nullptr;
}
2023-06-12 05:56:19 +00:00
~node_task() {
if (_handle) {
_handle.destroy();
2023-06-12 05:56:19 +00:00
}
}
node_task &operator=(node_task &&other) noexcept {
if (&other != this) {
if (_handle) {
_handle.destroy();
}
_handle = other._handle;
other._handle = nullptr;
2023-06-12 05:56:19 +00:00
}
return *this;
}
2023-06-12 05:56:19 +00:00
awaiter operator co_await() noexcept { return awaiter{_handle}; }
2023-06-12 05:56:19 +00:00
void resume() noexcept { _handle.resume(); }
2023-06-12 05:56:19 +00:00
bool done() const noexcept { return _handle.done(); }
2023-06-12 05:56:19 +00:00
task_result result() noexcept { return *_handle.promise().result; }
2023-06-12 05:56:19 +00:00
std::optional<task_result> try_result() noexcept {
return _handle.promise().result;
}
2023-06-12 05:56:19 +00:00
private:
handle_type _handle;
2023-06-12 05:56:19 +00:00
};
using yield = std::suspend_always;
template <typename Context> class behavior_node {
2023-06-12 05:56:19 +00:00
public:
virtual node_task tick(Context &ctx) noexcept = 0;
2023-06-12 05:56:19 +00:00
virtual ~behavior_node() = default;
2023-06-12 05:56:19 +00:00
};
template <typename Context>
2023-06-12 05:56:19 +00:00
class sequence_node : public behavior_node<Context> {
public:
void add_child(std::unique_ptr<behavior_node<Context>> &&child) {
_children.push_back(std::move(child));
}
node_task tick(Context &ctx) noexcept override {
std::size_t idx = 0;
while (idx < _children.size()) {
auto task = _children[idx]->tick(ctx);
while (!task.done()) {
task.resume();
if (idx < _children.size() - 1)
co_await yield{};
}
if (task.result() == task_result::failure) {
co_return task_result::failure;
}
2023-06-12 05:56:19 +00:00
++idx;
}
2023-06-12 05:56:19 +00:00
co_return task_result::success;
}
2023-06-12 05:56:19 +00:00
private:
std::vector<std::unique_ptr<behavior_node<Context>>> _children;
2023-06-12 05:56:19 +00:00
};
template <typename Context> auto sequence() {
return std::make_unique<sequence_node<Context>>();
2023-06-12 05:56:19 +00:00
}
template <typename Context>
2023-06-12 05:56:19 +00:00
class selector_node : public behavior_node<Context> {
public:
void add_child(std::unique_ptr<behavior_node<Context>> &&child) {
_children.push_back(std::move(child));
}
node_task tick(Context &ctx) noexcept override {
for (std::size_t idx = 0; idx < _children.size(); ++idx) {
auto task = _children[idx]->tick(ctx);
while (!task.done()) {
task.resume();
co_await yield{};
}
if (task.result() == task_result::success) {
co_return task_result::success;
}
2023-06-12 05:56:19 +00:00
}
co_return task_result::failure;
}
2023-06-12 05:56:19 +00:00
private:
std::vector<std::unique_ptr<behavior_node<Context>>> _children;
2023-06-12 05:56:19 +00:00
};
template <typename Context> auto selector() {
return std::make_unique<selector_node<Context>>();
2023-06-12 05:56:19 +00:00
}
template <typename Context, typename TickFn>
2023-06-12 05:56:19 +00:00
class action_node : public behavior_node<Context> {
public:
action_node(TickFn tick) noexcept : _tick(tick) {}
2023-06-12 05:56:19 +00:00
node_task tick(Context &ctx) noexcept override { return _tick(ctx); }
2023-06-12 05:56:19 +00:00
private:
TickFn _tick;
2023-06-12 05:56:19 +00:00
};
/// Creates an action_node that takes ownership of `tick` (moved, not copied).
template <typename Context, typename TickFn> auto action(TickFn tick) {
  return std::make_unique<action_node<Context, TickFn>>(std::move(tick));
}
template <typename Context, typename ConditionFn>
2023-06-12 05:56:19 +00:00
class conditional_node : public behavior_node<Context> {
public:
conditional_node(ConditionFn condition) noexcept : _condition(condition) {}
2023-06-12 05:56:19 +00:00
node_task tick(Context &ctx) noexcept override {
if (_condition(ctx)) {
co_return task_result::success;
2023-06-12 05:56:19 +00:00
}
co_return task_result::failure;
}
2023-06-12 05:56:19 +00:00
private:
ConditionFn _condition;
2023-06-12 05:56:19 +00:00
};
template <typename Context, typename ConditionFn>
2023-06-12 05:56:19 +00:00
auto conditional(ConditionFn condition) {
return std::make_unique<conditional_node<Context, ConditionFn>>(condition);
}
template <typename Context> class tree_builder {
public:
tree_builder() noexcept : _root(nullptr) {}
tree_builder &begin_sequence() {
auto sequence = std::make_unique<sequence_node<Context>>();
node_stack.push_back(sequence.get());
if (!_root) {
_root = std::move(sequence);
} else {
node_stack[node_stack.size() - 2]->add_child(std::move(sequence));
}
return *this;
}
tree_builder &end_sequence() {
if (!node_stack.empty() &&
dynamic_cast<sequence_node<Context> *>(node_stack.back()) != nullptr) {
node_stack.pop_back();
}
return *this;
}
tree_builder &begin_selector() {
auto selector = std::make_unique<selector_node<Context>>();
node_stack.push_back(selector.get());
if (!_root) {
_root = std::move(selector);
} else {
node_stack[node_stack.size() - 2]->add_child(std::move(selector));
}
return *this;
}
tree_builder &end_selector() {
if (!node_stack.empty() &&
dynamic_cast<selector_node<Context> *>(node_stack.back()) != nullptr) {
node_stack.pop_back();
}
return *this;
}
template <typename TickFn> tree_builder &action(TickFn tick) {
if (node_stack.empty()) {
throw std::runtime_error("Can't add an action outside a composite node.");
}
node_stack.back()->add_child(
std::make_unique<action_node<Context, TickFn>>(tick));
return *this;
}
template <typename ConditionFn>
tree_builder &conditional(ConditionFn condition) {
if (node_stack.empty()) {
throw std::runtime_error(
"Can't add a conditional outside a composite node.");
}
node_stack.back()->add_child(
std::make_unique<conditional_node<Context, ConditionFn>>(condition));
return *this;
}
std::unique_ptr<behavior_node<Context>> build() { return std::move(_root); }
private:
std::unique_ptr<behavior_node<Context>> _root;
std::vector<behavior_node<Context> *> node_stack;
};
} // namespace bt