265 lines
6.6 KiB
Rust
265 lines
6.6 KiB
Rust
use anyhow::Result;
|
|
use itertools::{Itertools, Permutations};
|
|
use parking_lot::Mutex;
|
|
use rand::{seq::SliceRandom, Rng};
|
|
use rhai::{
|
|
packages::{
|
|
ArithmeticPackage, BasicArrayPackage, BasicFnPackage, BasicIteratorPackage,
|
|
BasicMathPackage, BasicStringPackage, LanguageCorePackage, LogicPackage, MoreStringPackage,
|
|
Package,
|
|
},
|
|
CallFnOptions, CustomType, Dynamic, Engine, EvalAltResult, OptimizationLevel, ParseError,
|
|
Position, Scope, TypeBuilder, AST,
|
|
};
|
|
use std::{sync::Arc, vec::IntoIter};
|
|
|
|
use super::Agent;
|
|
use crate::{
|
|
board::{Board, PlayerAction},
|
|
util::Symb,
|
|
};
|
|
|
|
pub struct RhaiPer<T: Clone, I: Iterator<Item = T>> {
|
|
inner: Arc<Permutations<I>>,
|
|
}
|
|
|
|
impl<T: Clone, I: Clone + Iterator<Item = T>> IntoIterator for RhaiPer<T, I> {
|
|
type Item = Vec<T>;
|
|
type IntoIter = Permutations<I>;
|
|
|
|
fn into_iter(self) -> Self::IntoIter {
|
|
(*self.inner).clone()
|
|
}
|
|
}
|
|
|
|
impl<T: Clone, I: Iterator<Item = T>> Clone for RhaiPer<T, I> {
|
|
fn clone(&self) -> Self {
|
|
Self {
|
|
inner: self.inner.clone(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T: Clone + 'static, I: Clone + Iterator<Item = T> + 'static> CustomType for RhaiPer<T, I> {
|
|
fn build(mut builder: TypeBuilder<Self>) {
|
|
builder
|
|
.with_name("Perutations")
|
|
.is_iterable()
|
|
.with_fn("to_string", |_s: &mut Self| "Permutation {}".to_owned())
|
|
.with_fn("to_debug", |_s: &mut Self| "Permutation {}".to_owned());
|
|
}
|
|
}
|
|
|
|
//
|
|
//
|
|
//
|
|
|
|
pub struct Rhai<R: Rng + 'static> {
|
|
#[expect(dead_code)]
|
|
rng: Arc<Mutex<R>>,
|
|
|
|
engine: Engine,
|
|
script: AST,
|
|
scope: Scope<'static>,
|
|
print_callback: Arc<dyn Fn(&str) + 'static>,
|
|
}
|
|
|
|
impl<R: Rng + 'static> Rhai<R> {
|
|
pub fn new(
|
|
script: &str,
|
|
rng: R,
|
|
print_callback: impl Fn(&str) + 'static,
|
|
debug_callback: impl Fn(&str) + 'static,
|
|
) -> Result<Self, ParseError> {
|
|
let rng = Arc::new(Mutex::new(rng));
|
|
let print_callback = Arc::new(print_callback);
|
|
|
|
let engine = {
|
|
let mut engine = Engine::new_raw();
|
|
|
|
#[cfg(target_arch = "wasm32")]
|
|
fn performance() -> Result<web_sys::Performance, wasm_bindgen::JsValue> {
|
|
use wasm_bindgen::JsCast;
|
|
|
|
let global = web_sys::js_sys::global();
|
|
let performance = web_sys::js_sys::Reflect::get(&global, &"performance".into())?;
|
|
performance.dyn_into::<web_sys::Performance>()
|
|
}
|
|
|
|
#[cfg(target_arch = "wasm32")]
|
|
let start = {
|
|
let performance = performance().expect("performance should be available");
|
|
|
|
// In milliseconds
|
|
performance.now()
|
|
};
|
|
|
|
#[cfg(not(target_arch = "wasm32"))]
|
|
let start = {
|
|
use std::time::Instant;
|
|
Instant::now()
|
|
};
|
|
|
|
let max_secs: u64 = 5;
|
|
engine.on_progress(move |ops| {
|
|
if ops % 10_000 != 0 {
|
|
return None;
|
|
}
|
|
|
|
#[cfg(target_arch = "wasm32")]
|
|
let elapsed_s = {
|
|
let performance = performance().expect("performance should be available");
|
|
|
|
// In milliseconds
|
|
((performance.now() - start) / 1000.0).round() as u64
|
|
};
|
|
|
|
#[cfg(not(target_arch = "wasm32"))]
|
|
let elapsed_s = { start.elapsed().as_secs() };
|
|
|
|
if elapsed_s > max_secs {
|
|
return Some(
|
|
format!("Turn ran for more than {max_secs} seconds, exiting.").into(),
|
|
);
|
|
}
|
|
|
|
return None;
|
|
});
|
|
|
|
// Do not use FULL, rand functions are not pure
|
|
engine.set_optimization_level(OptimizationLevel::Simple);
|
|
|
|
engine.disable_symbol("eval");
|
|
engine.set_max_expr_depths(100, 100);
|
|
engine.set_max_strings_interned(1024);
|
|
engine.set_strict_variables(false);
|
|
engine.on_print({
|
|
let callback = print_callback.clone();
|
|
move |s| callback(s)
|
|
});
|
|
engine.on_debug(move |text, source, pos| {
|
|
debug_callback(&match (source, pos) {
|
|
(Some(source), Position::NONE) => format!("{source} | {text}"),
|
|
(Some(source), pos) => format!("{source} @ {pos:?} | {text}"),
|
|
(None, Position::NONE) => format!("{text}"),
|
|
(None, pos) => format!("{pos:?} | {text}"),
|
|
})
|
|
});
|
|
|
|
LanguageCorePackage::new().register_into_engine(&mut engine);
|
|
ArithmeticPackage::new().register_into_engine(&mut engine);
|
|
BasicIteratorPackage::new().register_into_engine(&mut engine);
|
|
LogicPackage::new().register_into_engine(&mut engine);
|
|
BasicStringPackage::new().register_into_engine(&mut engine);
|
|
MoreStringPackage::new().register_into_engine(&mut engine);
|
|
BasicMathPackage::new().register_into_engine(&mut engine);
|
|
BasicArrayPackage::new().register_into_engine(&mut engine);
|
|
BasicFnPackage::new().register_into_engine(&mut engine);
|
|
|
|
engine
|
|
.register_fn("rand_int", {
|
|
let rng = rng.clone();
|
|
move |from: i64, to: i64| rng.lock().gen_range(from..=to)
|
|
})
|
|
.register_fn("rand_bool", {
|
|
let rng = rng.clone();
|
|
move |p: f32| rng.lock().gen_bool(p as f64)
|
|
})
|
|
.register_fn("rand_symb", {
|
|
let rng = rng.clone();
|
|
move || Symb::new_random(&mut *rng.lock()).to_string()
|
|
})
|
|
.register_fn("rand_action", {
|
|
let rng = rng.clone();
|
|
move |board: Board| PlayerAction::new_random(&mut *rng.lock(), &board)
|
|
})
|
|
.register_fn("rand_shuffle", {
|
|
let rng = rng.clone();
|
|
move |mut vec: Vec<Dynamic>| {
|
|
vec.shuffle(&mut *rng.lock());
|
|
vec
|
|
}
|
|
})
|
|
.register_fn("is_op", |s: &str| {
|
|
Symb::from_str(s).map(|x| x.is_op()).unwrap_or(false)
|
|
})
|
|
.register_fn(
|
|
"permutations",
|
|
|v: Vec<Dynamic>, size: i64| -> Result<Dynamic, Box<EvalAltResult>> {
|
|
let size: usize = match size.try_into() {
|
|
Ok(x) => x,
|
|
Err(_) => {
|
|
return Err(format!("Invalid permutation size {size}").into());
|
|
}
|
|
};
|
|
|
|
let per = RhaiPer {
|
|
inner: v.into_iter().permutations(size).into(),
|
|
};
|
|
|
|
Ok(Dynamic::from(per))
|
|
},
|
|
);
|
|
|
|
engine
|
|
.build_type::<Board>()
|
|
.build_type::<PlayerAction>()
|
|
.build_type::<RhaiPer<Dynamic, IntoIter<Dynamic>>>();
|
|
engine
|
|
};
|
|
|
|
let script = engine.compile(script)?;
|
|
let scope = Scope::new(); // Not used
|
|
|
|
Ok(Self {
|
|
rng,
|
|
engine,
|
|
script,
|
|
scope,
|
|
print_callback,
|
|
})
|
|
}
|
|
|
|
pub fn print(&self, text: &str) {
|
|
(self.print_callback)(text);
|
|
}
|
|
}
|
|
|
|
impl<R: Rng + 'static> Agent for Rhai<R> {
|
|
type ErrorType = EvalAltResult;
|
|
|
|
fn name(&self) -> &'static str {
|
|
"Rhai"
|
|
}
|
|
|
|
fn step_min(&mut self, board: &Board) -> Result<PlayerAction, Self::ErrorType> {
|
|
let res = self.engine.call_fn_with_options::<PlayerAction>(
|
|
CallFnOptions::new().eval_ast(false),
|
|
&mut self.scope,
|
|
&self.script,
|
|
"step_min",
|
|
(board.clone(),),
|
|
);
|
|
|
|
match res {
|
|
Ok(x) => Ok(x),
|
|
Err(err) => Err(*err),
|
|
}
|
|
}
|
|
|
|
fn step_max(&mut self, board: &Board) -> Result<PlayerAction, Self::ErrorType> {
|
|
let res = self.engine.call_fn_with_options::<PlayerAction>(
|
|
CallFnOptions::new().eval_ast(false),
|
|
&mut self.scope,
|
|
&self.script,
|
|
"step_max",
|
|
(board.clone(),),
|
|
);
|
|
|
|
match res {
|
|
Ok(x) => Ok(x),
|
|
Err(err) => Err(*err),
|
|
}
|
|
}
|
|
}
|