Files
minimax/rust/minimax/src/agents/rhai.rs
2025-11-01 10:11:44 -07:00

265 lines
6.6 KiB
Rust

use anyhow::Result;
use itertools::{Itertools, Permutations};
use parking_lot::Mutex;
use rand::{seq::SliceRandom, Rng};
use rhai::{
packages::{
ArithmeticPackage, BasicArrayPackage, BasicFnPackage, BasicIteratorPackage,
BasicMathPackage, BasicStringPackage, LanguageCorePackage, LogicPackage, MoreStringPackage,
Package,
},
CallFnOptions, CustomType, Dynamic, Engine, EvalAltResult, OptimizationLevel, ParseError,
Position, Scope, TypeBuilder, AST,
};
use std::{sync::Arc, vec::IntoIter};
use super::Agent;
use crate::{
board::{Board, PlayerAction},
util::Symb,
};
/// Handle to an `itertools` [`Permutations`] iterator held behind an [`Arc`].
///
/// Cloning the handle only clones the `Arc`; the underlying iterator state is
/// copied when the handle is turned into an iterator.
pub struct RhaiPer<T, I>
where
    T: Clone,
    I: Iterator<Item = T>,
{
    /// The permutation iterator, shared between all clones of this handle.
    inner: Arc<Permutations<I>>,
}
impl<T: Clone, I: Clone + Iterator<Item = T>> IntoIterator for RhaiPer<T, I> {
type Item = Vec<T>;
type IntoIter = Permutations<I>;
fn into_iter(self) -> Self::IntoIter {
(*self.inner).clone()
}
}
impl<T: Clone, I: Iterator<Item = T>> Clone for RhaiPer<T, I> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
/// Describes [`RhaiPer`] to the Rhai engine as an iterable custom type.
impl<T: Clone + 'static, I: Clone + Iterator<Item = T> + 'static> CustomType for RhaiPer<T, I> {
    /// Registers the script-visible type name, iteration support and the
    /// string conversions used when a script prints or debugs the value.
    fn build(mut builder: TypeBuilder<Self>) {
        builder
            // Script-visible type name (previously misspelled "Perutations").
            .with_name("Permutations")
            .is_iterable()
            // The iterator contents are not rendered; both conversions return
            // a fixed placeholder.
            .with_fn("to_string", |_s: &mut Self| "Permutation {}".to_owned())
            .with_fn("to_debug", |_s: &mut Self| "Permutation {}".to_owned());
    }
}
//
//
//
/// An agent whose moves are computed by a user-supplied Rhai script.
pub struct Rhai<R>
where
    R: Rng + 'static,
{
    // Never read through this field: the closures registered on `engine`
    // hold their own clones of the `Arc`, hence the `dead_code` expectation.
    #[expect(dead_code)]
    rng: Arc<Mutex<R>>,

    /// Engine the script is run on.
    engine: Engine,

    /// The compiled script.
    script: AST,

    /// Scope handed to the engine on every script call.
    scope: Scope<'static>,

    /// Sink for text printed by the script and by [`Rhai::print`].
    print_callback: Arc<dyn Fn(&str) + 'static>,
}
impl<R: Rng + 'static> Rhai<R> {
    /// Compiles `script` and builds the engine that will run it.
    ///
    /// The engine starts from [`Engine::new_raw`] and only gets the packages
    /// and helper functions registered below. `eval` is disabled and every
    /// script call is aborted once it has run for more than five seconds.
    ///
    /// # Parameters
    /// - `script`: Rhai source code to compile.
    /// - `rng`: random source shared by all `rand_*` script functions.
    /// - `print_callback`: receives text passed to the script's `print`.
    /// - `debug_callback`: receives text passed to the script's `debug`,
    ///   prefixed with the source name and position when available.
    ///
    /// # Errors
    /// Returns a [`ParseError`] if `script` fails to compile.
    pub fn new(
        script: &str,
        rng: R,
        print_callback: impl Fn(&str) + 'static,
        debug_callback: impl Fn(&str) + 'static,
    ) -> Result<Self, ParseError> {
        let rng = Arc::new(Mutex::new(rng));
        let print_callback = Arc::new(print_callback);

        let engine = {
            let mut engine = Engine::new_raw();

            #[cfg(target_arch = "wasm32")]
            fn performance() -> Result<web_sys::Performance, wasm_bindgen::JsValue> {
                use wasm_bindgen::JsCast;
                let global = web_sys::js_sys::global();
                let performance = web_sys::js_sys::Reflect::get(&global, &"performance".into())?;
                performance.dyn_into::<web_sys::Performance>()
            }

            // Current reading of the clock used for the time limit.
            // On wasm this is `performance.now()`, in milliseconds.
            #[cfg(target_arch = "wasm32")]
            fn clock_now() -> f64 {
                performance()
                    .expect("performance should be available")
                    .now()
            }

            #[cfg(not(target_arch = "wasm32"))]
            fn clock_now() -> std::time::Instant {
                std::time::Instant::now()
            }

            let max_secs: u64 = 5;

            // Start time of the script call currently being executed.
            // It must be reset for every call: measuring from engine creation
            // would abort every turn once the agent is older than `max_secs`.
            let turn_start = Mutex::new(clock_now());

            engine.on_progress(move |ops| {
                // The operation counter restarts for each script call, so the
                // first operation marks the beginning of a new turn.
                if ops <= 1 {
                    *turn_start.lock() = clock_now();
                }

                // Only look at the clock every 10 000 operations.
                if ops % 10_000 != 0 {
                    return None;
                }

                #[cfg(target_arch = "wasm32")]
                let elapsed_s = ((clock_now() - *turn_start.lock()) / 1000.0).round() as u64;

                #[cfg(not(target_arch = "wasm32"))]
                let elapsed_s = turn_start.lock().elapsed().as_secs();

                if elapsed_s > max_secs {
                    // Returning a value terminates the running script.
                    return Some(
                        format!("Turn ran for more than {max_secs} seconds, exiting.").into(),
                    );
                }

                None
            });

            // Do not use FULL, rand functions are not pure
            engine.set_optimization_level(OptimizationLevel::Simple);
            engine.disable_symbol("eval");
            engine.set_max_expr_depths(100, 100);
            engine.set_max_strings_interned(1024);
            engine.set_strict_variables(false);

            engine.on_print({
                let callback = print_callback.clone();
                move |s| callback(s)
            });

            engine.on_debug(move |text, source, pos| {
                debug_callback(&match (source, pos) {
                    (Some(source), Position::NONE) => format!("{source} | {text}"),
                    (Some(source), pos) => format!("{source} @ {pos:?} | {text}"),
                    (None, Position::NONE) => text.to_owned(),
                    (None, pos) => format!("{pos:?} | {text}"),
                })
            });

            LanguageCorePackage::new().register_into_engine(&mut engine);
            ArithmeticPackage::new().register_into_engine(&mut engine);
            BasicIteratorPackage::new().register_into_engine(&mut engine);
            LogicPackage::new().register_into_engine(&mut engine);
            BasicStringPackage::new().register_into_engine(&mut engine);
            MoreStringPackage::new().register_into_engine(&mut engine);
            BasicMathPackage::new().register_into_engine(&mut engine);
            BasicArrayPackage::new().register_into_engine(&mut engine);
            BasicFnPackage::new().register_into_engine(&mut engine);

            engine
                .register_fn("rand_int", {
                    let rng = rng.clone();
                    move |from: i64, to: i64| -> Result<i64, Box<EvalAltResult>> {
                        // `gen_range` panics on an empty range; the bounds
                        // come from the script, so report an error instead.
                        if from > to {
                            return Err(format!(
                                "Invalid range for rand_int: {from} is greater than {to}"
                            )
                            .into());
                        }
                        Ok(rng.lock().gen_range(from..=to))
                    }
                })
                .register_fn("rand_bool", {
                    let rng = rng.clone();
                    move |p: f32| -> Result<bool, Box<EvalAltResult>> {
                        let p = f64::from(p);
                        // `gen_bool` panics outside [0, 1]; NaN is rejected
                        // here too because it is not contained in the range.
                        if !(0.0..=1.0).contains(&p) {
                            return Err(format!(
                                "Invalid probability for rand_bool: {p} is not in [0, 1]"
                            )
                            .into());
                        }
                        Ok(rng.lock().gen_bool(p))
                    }
                })
                .register_fn("rand_symb", {
                    let rng = rng.clone();
                    move || Symb::new_random(&mut *rng.lock()).to_string()
                })
                .register_fn("rand_action", {
                    let rng = rng.clone();
                    move |board: Board| PlayerAction::new_random(&mut *rng.lock(), &board)
                })
                .register_fn("rand_shuffle", {
                    let rng = rng.clone();
                    move |mut vec: Vec<Dynamic>| {
                        vec.shuffle(&mut *rng.lock());
                        vec
                    }
                })
                .register_fn("is_op", |s: &str| {
                    Symb::from_str(s).map(|x| x.is_op()).unwrap_or(false)
                })
                .register_fn(
                    "permutations",
                    |v: Vec<Dynamic>, size: i64| -> Result<Dynamic, Box<EvalAltResult>> {
                        // Negative sizes cannot be converted to `usize`.
                        let size: usize = match size.try_into() {
                            Ok(x) => x,
                            Err(_) => {
                                return Err(format!("Invalid permutation size {size}").into());
                            }
                        };
                        let per = RhaiPer {
                            inner: v.into_iter().permutations(size).into(),
                        };
                        Ok(Dynamic::from(per))
                    },
                );

            engine
                .build_type::<Board>()
                .build_type::<PlayerAction>()
                .build_type::<RhaiPer<Dynamic, IntoIter<Dynamic>>>();

            engine
        };

        let script = engine.compile(script)?;
        let scope = Scope::new(); // Not used

        Ok(Self {
            rng,
            engine,
            script,
            scope,
            print_callback,
        })
    }

    /// Sends `text` to the print callback given to [`Rhai::new`].
    pub fn print(&self, text: &str) {
        (self.print_callback)(text);
    }
}
/// Each move is delegated to the script function of the same name.
impl<R: Rng + 'static> Agent for Rhai<R> {
    type ErrorType = EvalAltResult;

    /// Fixed display name of this agent.
    fn name(&self) -> &'static str {
        "Rhai"
    }

    /// Calls the script's `step_min(board)` and returns the action it produces.
    ///
    /// # Errors
    /// Returns the script's evaluation error if the call fails.
    fn step_min(&mut self, board: &Board) -> Result<PlayerAction, Self::ErrorType> {
        self.call_step("step_min", board)
    }

    /// Calls the script's `step_max(board)` and returns the action it produces.
    ///
    /// # Errors
    /// Returns the script's evaluation error if the call fails.
    fn step_max(&mut self, board: &Board) -> Result<PlayerAction, Self::ErrorType> {
        self.call_step("step_max", board)
    }
}

impl<R: Rng + 'static> Rhai<R> {
    /// Calls the script function `fn_name` with a copy of `board`.
    ///
    /// The script's top-level statements are not re-evaluated
    /// (`eval_ast(false)`); only the named function runs.
    fn call_step(&mut self, fn_name: &str, board: &Board) -> Result<PlayerAction, EvalAltResult> {
        self.engine
            .call_fn_with_options::<PlayerAction>(
                CallFnOptions::new().eval_ast(false),
                &mut self.scope,
                &self.script,
                fn_name,
                (board.clone(),),
            )
            // The engine returns a boxed error; the trait wants it unboxed.
            .map_err(|err| *err)
    }
}