diff --git a/rust/minimax/Cargo.toml b/rust/minimax/Cargo.toml new file mode 100644 index 0000000..383f9a0 --- /dev/null +++ b/rust/minimax/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "minimax" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = { workspace = true } +itertools = { workspace = true } +rand = { workspace = true } + +rhai = { workspace = true } +parking_lot = { workspace = true } + +[target.'cfg(target_arch = "wasm32")'.dependencies] +rhai = { workspace = true, features = ["wasm-bindgen"] } +web-sys = { workspace = true } +wasm-bindgen = { workspace = true } diff --git a/rust/minimax/src/agents/mod.rs b/rust/minimax/src/agents/mod.rs new file mode 100644 index 0000000..7530d06 --- /dev/null +++ b/rust/minimax/src/agents/mod.rs @@ -0,0 +1,18 @@ +mod rhai; + +pub use rhai::Rhai; + +use crate::board::{Board, PlayerAction}; +use anyhow::Result; + +pub trait Agent { + type ErrorType; + + fn name(&self) -> &'static str; + + /// Try to minimize the value of a board. + fn step_min(&mut self, board: &Board) -> Result; + + /// Try to maximize the value of a board. + fn step_max(&mut self, board: &Board) -> Result; +} diff --git a/rust/minimax/src/agents/rhai.rs b/rust/minimax/src/agents/rhai.rs new file mode 100644 index 0000000..af19e0d --- /dev/null +++ b/rust/minimax/src/agents/rhai.rs @@ -0,0 +1,264 @@ +use anyhow::Result; +use itertools::{Itertools, Permutations}; +use parking_lot::Mutex; +use rand::{seq::SliceRandom, Rng}; +use rhai::{ + packages::{ + ArithmeticPackage, BasicArrayPackage, BasicFnPackage, BasicIteratorPackage, + BasicMathPackage, BasicStringPackage, LanguageCorePackage, LogicPackage, MoreStringPackage, + Package, + }, + CallFnOptions, CustomType, Dynamic, Engine, EvalAltResult, OptimizationLevel, ParseError, + Position, Scope, TypeBuilder, AST, +}; +use std::{sync::Arc, vec::IntoIter}; + +use super::Agent; +use crate::{ + board::{Board, PlayerAction}, + util::Symb, +}; + +pub struct RhaiPer> { + inner: Arc>, +} + +impl> IntoIterator for RhaiPer { + type Item = Vec; + type IntoIter = Permutations; + + fn into_iter(self) -> Self::IntoIter { + (*self.inner).clone() + } +} + +impl> Clone for RhaiPer { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +impl + 'static> CustomType for RhaiPer { + fn build(mut builder: TypeBuilder) { + builder + .with_name("Perutations") + .is_iterable() + .with_fn("to_string", |_s: &mut Self| "Permutation {}".to_owned()) + .with_fn("to_debug", |_s: &mut Self| "Permutation {}".to_owned()); + } +} + +// +// +// + +pub struct Rhai { + #[expect(dead_code)] + rng: Arc>, + + engine: Engine, + script: AST, + scope: Scope<'static>, + print_callback: Arc, +} + +impl Rhai { + pub fn new( + script: &str, + rng: R, + print_callback: impl Fn(&str) + 'static, + debug_callback: impl Fn(&str) + 'static, + ) -> Result { + let rng = Arc::new(Mutex::new(rng)); + let print_callback = Arc::new(print_callback); + + let engine = { + let mut engine = Engine::new_raw(); + + #[cfg(target_arch = "wasm32")] + fn performance() -> Result { + use wasm_bindgen::JsCast; + + let global = web_sys::js_sys::global(); + let performance = web_sys::js_sys::Reflect::get(&global, &"performance".into())?; + performance.dyn_into::() + } + + #[cfg(target_arch = "wasm32")] + let start = { + let performance = performance().expect("performance should be available"); + + // In milliseconds + performance.now() + }; + + #[cfg(not(target_arch = "wasm32"))] + let start = { + use std::time::Instant; + Instant::now() + }; + + let max_secs: u64 = 5; + engine.on_progress(move |ops| { + if ops % 10_000 != 0 { + return None; + } + + #[cfg(target_arch = "wasm32")] + let elapsed_s = { + let performance = performance().expect("performance should be available"); + + // In milliseconds + ((performance.now() - start) / 1000.0).round() as u64 + }; + + #[cfg(not(target_arch = "wasm32"))] + let elapsed_s = { start.elapsed().as_secs() }; + + if elapsed_s > max_secs { + return Some( + format!("Turn ran for more than {max_secs} seconds, exiting.").into(), + ); + } + + return None; + }); + + // Do not use FULL, rand functions are not pure + engine.set_optimization_level(OptimizationLevel::Simple); + + engine.disable_symbol("eval"); + engine.set_max_expr_depths(100, 100); + engine.set_max_strings_interned(1024); + engine.set_strict_variables(false); + engine.on_print({ + let callback = print_callback.clone(); + move |s| callback(s) + }); + engine.on_debug(move |text, source, pos| { + debug_callback(&match (source, pos) { + (Some(source), Position::NONE) => format!("{source} | {text}"), + (Some(source), pos) => format!("{source} @ {pos:?} | {text}"), + (None, Position::NONE) => format!("{text}"), + (None, pos) => format!("{pos:?} | {text}"), + }) + }); + + LanguageCorePackage::new().register_into_engine(&mut engine); + ArithmeticPackage::new().register_into_engine(&mut engine); + BasicIteratorPackage::new().register_into_engine(&mut engine); + LogicPackage::new().register_into_engine(&mut engine); + BasicStringPackage::new().register_into_engine(&mut engine); + MoreStringPackage::new().register_into_engine(&mut engine); + BasicMathPackage::new().register_into_engine(&mut engine); + BasicArrayPackage::new().register_into_engine(&mut engine); + BasicFnPackage::new().register_into_engine(&mut engine); + + engine + .register_fn("rand_int", { + let rng = rng.clone(); + move |from: i64, to: i64| rng.lock().gen_range(from..=to) + }) + .register_fn("rand_bool", { + let rng = rng.clone(); + move |p: f32| rng.lock().gen_bool(p as f64) + }) + .register_fn("rand_symb", { + let rng = rng.clone(); + move || Symb::new_random(&mut *rng.lock()).to_string() + }) + .register_fn("rand_action", { + let rng = rng.clone(); + move |board: Board| PlayerAction::new_random(&mut *rng.lock(), &board) + }) + .register_fn("rand_shuffle", { + let rng = rng.clone(); + move |mut vec: Vec| { + vec.shuffle(&mut *rng.lock()); + vec + } + }) + .register_fn("is_op", |s: &str| { + Symb::from_str(s).map(|x| x.is_op()).unwrap_or(false) + }) + .register_fn( + "permutations", + |v: Vec, size: i64| -> Result> { + let size: usize = match size.try_into() { + Ok(x) => x, + Err(_) => { + return Err(format!("Invalid permutation size {size}").into()); + } + }; + + let per = RhaiPer { + inner: v.into_iter().permutations(size).into(), + }; + + Ok(Dynamic::from(per)) + }, + ); + + engine + .build_type::() + .build_type::() + .build_type::>>(); + engine + }; + + let script = engine.compile(script)?; + let scope = Scope::new(); // Not used + + Ok(Self { + rng, + engine, + script, + scope, + print_callback, + }) + } + + pub fn print(&self, text: &str) { + (self.print_callback)(text); + } +} + +impl Agent for Rhai { + type ErrorType = EvalAltResult; + + fn name(&self) -> &'static str { + "Rhai" + } + + fn step_min(&mut self, board: &Board) -> Result { + let res = self.engine.call_fn_with_options::( + CallFnOptions::new().eval_ast(false), + &mut self.scope, + &self.script, + "step_min", + (board.clone(),), + ); + + match res { + Ok(x) => Ok(x), + Err(err) => Err(*err), + } + } + + fn step_max(&mut self, board: &Board) -> Result { + let res = self.engine.call_fn_with_options::( + CallFnOptions::new().eval_ast(false), + &mut self.scope, + &self.script, + "step_max", + (board.clone(),), + ); + + match res { + Ok(x) => Ok(x), + Err(err) => Err(*err), + } + } +} diff --git a/rust/minimax/src/board/board.rs b/rust/minimax/src/board/board.rs new file mode 100644 index 0000000..e87258a --- /dev/null +++ b/rust/minimax/src/board/board.rs @@ -0,0 +1,596 @@ +use anyhow::Result; +use itertools::Itertools; +use rhai::Array; +use rhai::CustomType; +use rhai::Dynamic; +use rhai::EvalAltResult; +use rhai::Position; +use rhai::TypeBuilder; +use std::fmt::{Debug, Display, Write}; + +use super::{PlayerAction, TreeElement}; +use crate::util::Symb; + +#[derive(Debug)] +enum InterTreeElement { + Unprocessed(Token), + Processed(TreeElement), +} + +impl InterTreeElement { + fn to_value(&self) -> Option { + Some(match self { + InterTreeElement::Processed(x) => x.clone(), + InterTreeElement::Unprocessed(Token::Value(s)) => { + if let Some(s) = s.strip_prefix('-') { + TreeElement::Neg { + r: { + if s.contains('_') { + Box::new(TreeElement::Partial(s.to_string())) + } else { + Box::new(TreeElement::Number(match s.parse() { + Ok(x) => x, + _ => return None, + })) + } + }, + } + } else if s.contains('_') { + TreeElement::Partial(s.to_string()) + } else { + TreeElement::Number(match s.parse() { + Ok(x) => x, + _ => return None, + }) + } + } + _ => return None, + }) + } +} + +#[derive(Debug, PartialEq, Clone)] +enum Token { + Value(String), + OpAdd, + OpSub, + OpMult, + OpDiv, +} + +#[derive(Clone)] +pub struct Board { + board: [Option; 11], + placed_by: [Option; 11], + + /// Number of Nones in `board` + free_spots: usize, + + /// Index of the last board index that was changed + last_placed: Option, +} + +impl Display for Board { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for c in self.board { + write!(f, "{}", c.map(|s| s.get_char().unwrap()).unwrap_or('_'))? + } + Ok(()) + } +} + +impl Debug for Board { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self, f) + } +} + +#[allow(dead_code)] +impl Board { + pub fn new() -> Self { + Self { + free_spots: 11, + board: Default::default(), + placed_by: Default::default(), + last_placed: None, + } + } + + pub fn get_board(&self) -> &[Option; 11] { + &self.board + } + + pub fn get_board_mut(&mut self) -> &mut [Option; 11] { + &mut self.board + } + + /// Get the index of the ith empty slot + pub fn ith_empty_slot(&self, mut idx: usize) -> Option { + for (i, c) in self.board.iter().enumerate() { + if c.is_none() { + if idx == 0 { + return Some(i); + } + idx -= 1; + } + } + + if idx == 0 { + Some(self.board.len() - 1) + } else { + None + } + } + + pub fn is_full(&self) -> bool { + self.free_spots == 0 + } + + pub fn prettyprint(&self) -> String { + const RESET: &str = "\x1b[0m"; + const MAGENTA: &str = "\x1b[35m"; + + let mut s = String::new(); + // Print board + for (i, (symb, _)) in self.board.iter().zip(self.placed_by.iter()).enumerate() { + match symb { + Some(symb) => write!( + s, + "{}{}{RESET}", + // If index matches last placed, draw symbol in red. + // If last_placed is None, this check will always fail + // since self.board.len is always greater than i. + if self.last_placed.unwrap_or(self.board.len()) == i { + MAGENTA + } else { + RESET + }, + symb, + ) + .unwrap(), + + None => write!(s, "_").unwrap(), + } + } + + s + } + + pub fn size(&self) -> usize { + self.board.len() + } + + pub fn get_last_placed(&self) -> Option { + self.last_placed + } + + pub fn contains(&self, s: Symb) -> bool { + self.board.iter().contains(&Some(s)) + } + + /// Is the given action valid? + pub fn can_play(&self, action: &PlayerAction) -> bool { + match &self.board[action.pos] { + Some(_) => return false, + None => { + // Check for duplicate symbols + if self.contains(action.symb) { + return false; + } + + // Check syntax + match action.symb { + Symb::Minus => { + if action.pos == self.board.len() - 1 { + return false; + } + + let r = &self.board[action.pos + 1]; + if r.is_some_and(|s| s.is_op() && !s.is_minus()) { + return false; + } + } + + Symb::Zero => { + if action.pos != 0 { + let l = &self.board[action.pos - 1]; + if l == &Some(Symb::Div) { + return false; + } + } + } + + Symb::Div | Symb::Plus | Symb::Times => { + if action.pos == 0 || action.pos == self.board.len() - 1 { + return false; + } + + let l = &self.board[action.pos - 1]; + let r = &self.board[action.pos + 1]; + + if action.symb == Symb::Div && r == &Some(Symb::Zero) { + return false; + } + + if l.is_some_and(|s| s.is_op()) + || r.is_some_and(|s| s.is_op() && !s.is_minus()) + { + return false; + } + } + _ => {} + } + } + } + + true + } + + /// Place the marked symbol at the given position. + /// Returns true for valid moves and false otherwise. + pub fn play(&mut self, action: PlayerAction, player: impl Into) -> bool { + if !self.can_play(&action) { + return false; + } + + self.board[action.pos] = Some(action.symb); + self.placed_by[action.pos] = Some(player.into()); + self.free_spots -= 1; + self.last_placed = Some(action.pos); + true + } + + fn tokenize(&self) -> Vec { + let mut tokens = Vec::new(); + let mut is_neg = true; // if true, - is negative. if false, subtract. + let mut current_num = String::new(); + + for s in self.board.iter() { + match s { + Some(Symb::Div) => { + if !current_num.is_empty() { + tokens.push(Token::Value(current_num.clone())); + current_num.clear(); + } + tokens.push(Token::OpDiv); + is_neg = true; + } + Some(Symb::Minus) => { + if is_neg { + current_num = format!("-{}", current_num); + } else { + if !current_num.is_empty() { + tokens.push(Token::Value(current_num.clone())); + current_num.clear(); + } + tokens.push(Token::OpSub); + is_neg = true; + } + } + Some(Symb::Plus) => { + if !current_num.is_empty() { + tokens.push(Token::Value(current_num.clone())); + current_num.clear(); + } + tokens.push(Token::OpAdd); + is_neg = true; + } + Some(Symb::Times) => { + if !current_num.is_empty() { + tokens.push(Token::Value(current_num.clone())); + current_num.clear(); + } + tokens.push(Token::OpMult); + is_neg = true; + } + Some(Symb::Zero) => { + current_num.push('0'); + is_neg = false; + } + Some(Symb::Number(x)) => { + current_num.push_str(&x.to_string()); + is_neg = false; + } + None => { + current_num.push('_'); + is_neg = false; + } + } + } + + if !current_num.is_empty() { + tokens.push(Token::Value(current_num.clone())); + } + + tokens + } + + pub fn to_tree(&self) -> Option { + let tokens = self.tokenize(); + + let mut tree: Vec<_> = tokens + .iter() + .map(|x| InterTreeElement::Unprocessed(x.clone())) + .collect(); + + let mut priority_level = 0; + let mut did_something; + while tree.len() > 1 { + did_something = false; + for i in 0..tree.len() { + if match priority_level { + 0 => matches!( + tree[i], + InterTreeElement::Unprocessed(Token::OpMult) + | InterTreeElement::Unprocessed(Token::OpDiv) + ), + 1 => matches!( + tree[i], + InterTreeElement::Unprocessed(Token::OpAdd) + | InterTreeElement::Unprocessed(Token::OpSub) + ), + _ => false, + } { + did_something = true; + + if i == 0 || i + 1 >= tree.len() { + return None; + } + + let l = tree[i - 1].to_value()?; + let r = tree[i + 1].to_value()?; + + let v = match tree[i] { + InterTreeElement::Unprocessed(Token::OpAdd) => TreeElement::Add { + l: Box::new(l), + r: Box::new(r), + }, + InterTreeElement::Unprocessed(Token::OpDiv) => TreeElement::Div { + l: Box::new(l), + r: Box::new(r), + }, + InterTreeElement::Unprocessed(Token::OpMult) => TreeElement::Mul { + l: Box::new(l), + r: Box::new(r), + }, + InterTreeElement::Unprocessed(Token::OpSub) => TreeElement::Sub { + l: Box::new(l), + r: Box::new(r), + }, + _ => unreachable!(), + }; + + tree.remove(i - 1); + tree.remove(i - 1); + tree[i - 1] = InterTreeElement::Processed(v); + break; + } + } + + if !did_something { + priority_level += 1; + } + } + + Some(match tree.into_iter().next().unwrap() { + InterTreeElement::Processed(x) => x, + x => x.to_value()?, + }) + } + + pub fn evaluate(&self) -> Option { + self.to_tree()?.evaluate() + } + + pub fn from_board(board: [Option; 11]) -> Self { + let free_spots = board.iter().filter(|x| x.is_none()).count(); + Self { + board, + placed_by: Default::default(), + free_spots, + last_placed: None, + } + } + + /// Parse a board from a string + pub fn from_string(s: &str) -> Option { + if s.len() != 11 { + return None; + } + + let x = s + .chars() + .filter_map(|c| { + if c == '_' { + Some(None) + } else { + Symb::from_char(c).map(Some) + } + }) + .collect::>(); + + if x.len() != 11 { + return None; + } + + let mut free_spots = 11; + let mut board = [None; 11]; + for i in 0..x.len() { + board[i] = x[i]; + if x[i].is_some() { + free_spots -= 1; + } + } + + Some(Self { + board, + placed_by: Default::default(), + free_spots, + last_placed: None, + }) + } +} + +impl IntoIterator for Board { + type Item = String; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.board + .iter() + .map(|x| x.map(|x| x.to_string()).unwrap_or_default()) + .collect::>() + .into_iter() + } +} + +impl CustomType for Board { + fn build(mut builder: TypeBuilder) { + builder + .with_name("Board") + .is_iterable() + .with_fn("to_string", |s: &mut Self| format!("{}", s)) + .with_fn("to_debug", |s: &mut Self| format!("{:?}", s)) + .with_fn("size", |s: &mut Self| s.board.len() as i64) + .with_fn("len", |s: &mut Self| s.board.len() as i64) + .with_fn("is_full", |s: &mut Self| s.is_full()) + .with_fn("free_spots", |s: &mut Self| s.free_spots) + .with_fn("play", |s: &mut Self, act: PlayerAction| { + s.play(act, "NONE".to_owned()) // Player doesn't matter + }) + .with_fn("ith_free_slot", |s: &mut Self, idx: usize| { + s.ith_empty_slot(idx).map(|x| x as i64).unwrap_or(-1) + }) + .with_fn("can_play", |s: &mut Self, act: PlayerAction| { + s.can_play(&act) + }) + .with_fn("contains", |s: &mut Self, sym: &str| { + match Symb::from_str(sym) { + None => false, + Some(x) => s.contains(x), + } + }) + .with_fn("contains", |s: &mut Self, sym: i64| { + let sym = sym.to_string(); + match Symb::from_str(&sym) { + None => false, + Some(x) => s.contains(x), + } + }) + .with_fn("evaluate", |s: &mut Self| -> Dynamic { + s.evaluate().map(|x| x.into()).unwrap_or(().into()) + }) + .with_fn("free_spots_idx", |s: &mut Self| -> Array { + s.board + .iter() + .enumerate() + .filter(|(_, x)| x.is_none()) + .map(|(i, _)| i as i64) + .map(|x| x.into()) + .collect::>() + }) + .with_indexer_get( + |s: &mut Self, idx: i64| -> Result> { + if idx as usize >= s.board.len() { + return Err( + EvalAltResult::ErrorIndexNotFound(idx.into(), Position::NONE).into(), + ); + } + + let idx = idx as usize; + return Ok(s.board[idx].map(|x| x.to_string()).unwrap_or_default()); + }, + ) + .with_indexer_set( + |s: &mut Self, idx: i64, val: String| -> Result<(), Box> { + let idx: usize = match idx.try_into() { + Ok(x) => x, + Err(_) => { + return Err(EvalAltResult::ErrorIndexNotFound( + idx.into(), + Position::NONE, + ) + .into()); + } + }; + + if idx >= s.board.len() { + return Err(EvalAltResult::ErrorIndexNotFound( + (idx as i64).into(), + Position::NONE, + ) + .into()); + } + + match Symb::from_str(&val) { + None => return Err(format!("Invalid symbol {val}").into()), + Some(x) => { + s.board[idx] = Some(x); + s.placed_by[idx] = Some("NONE".to_owned()); // Arbitrary + } + } + + return Ok(()); + }, + ) + .with_indexer_set( + |s: &mut Self, idx: i64, _val: ()| -> Result<(), Box> { + let idx: usize = match idx.try_into() { + Ok(x) => x, + Err(_) => { + return Err(EvalAltResult::ErrorIndexNotFound( + idx.into(), + Position::NONE, + ) + .into()); + } + }; + + if idx >= s.board.len() { + return Err(EvalAltResult::ErrorIndexNotFound( + (idx as i64).into(), + Position::NONE, + ) + .into()); + } + + s.board[idx] = None; + s.placed_by[idx] = None; + + return Ok(()); + }, + ) + .with_indexer_set( + |s: &mut Self, idx: i64, val: i64| -> Result<(), Box> { + let idx: usize = match idx.try_into() { + Ok(x) => x, + Err(_) => { + return Err(EvalAltResult::ErrorIndexNotFound( + idx.into(), + Position::NONE, + ) + .into()); + } + }; + + if idx >= s.board.len() { + return Err(EvalAltResult::ErrorIndexNotFound( + (idx as i64).into(), + Position::NONE, + ) + .into()); + } + + match Symb::from_str(&val.to_string()) { + None => return Err(format!("Invalid symbol {val}").into()), + Some(x) => { + s.board[idx] = Some(x); + s.placed_by[idx] = Some("NULL".to_owned()); // Arbitrary + } + } + + return Ok(()); + }, + ); + } +} diff --git a/rust/minimax/src/board/mod.rs b/rust/minimax/src/board/mod.rs new file mode 100644 index 0000000..1162e68 --- /dev/null +++ b/rust/minimax/src/board/mod.rs @@ -0,0 +1,77 @@ +#[allow(clippy::module_inception)] +mod board; +mod tree; + +use rand::Rng; +use rhai::{CustomType, EvalAltResult, TypeBuilder}; +use std::fmt::Display; + +pub use board::Board; +pub use tree::TreeElement; + +use crate::util::Symb; + +#[derive(Debug, Clone, Copy)] +pub struct PlayerAction { + pub symb: Symb, + pub pos: usize, +} + +impl Display for PlayerAction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} at {}", self.symb, self.pos) + } +} + +impl PlayerAction { + pub fn new_random(rng: &mut R, board: &Board) -> Self { + let n = board.size(); + let pos = rng.gen_range(0..n); + let symb = Symb::new_random(rng); + PlayerAction { symb, pos } + } +} + +impl CustomType for PlayerAction { + fn build(mut builder: TypeBuilder) { + builder + .with_name("Action") + .with_fn( + "Action", + |symb: &str, pos: i64| -> Result> { + let symb = match Symb::from_str(symb) { + Some(x) => x, + None => return Err(format!("Invalid symbol {symb:?}").into()), + }; + + Ok(Self { + symb, + pos: pos as usize, + }) + }, + ) + .with_fn( + "Action", + |symb: i64, pos: i64| -> Result> { + let symb = symb.to_string(); + let symb = match Symb::from_str(&symb) { + Some(x) => x, + None => return Err(format!("Invalid symbol {symb:?}").into()), + }; + + Ok(Self { + symb, + pos: pos as usize, + }) + }, + ) + .with_fn("to_string", |s: &mut Self| -> String { + format!("Action {{{} at {}}}", s.symb, s.pos) + }) + .with_fn("to_debug", |s: &mut Self| -> String { + format!("Action {{{} at {}}}", s.symb, s.pos) + }) + .with_get("symb", |s: &mut Self| s.symb.to_string()) + .with_get("pos", |s: &mut Self| s.pos); + } +} diff --git a/rust/minimax/src/board/tree.rs b/rust/minimax/src/board/tree.rs new file mode 100644 index 0000000..93e67af --- /dev/null +++ b/rust/minimax/src/board/tree.rs @@ -0,0 +1,143 @@ +use std::fmt::{Debug, Display}; + +#[derive(PartialEq, Clone)] +pub enum TreeElement { + Partial(String), + Number(f32), + Add { + l: Box, + r: Box, + }, + Sub { + l: Box, + r: Box, + }, + Mul { + l: Box, + r: Box, + }, + Div { + l: Box, + r: Box, + }, + Neg { + r: Box, + }, +} + +impl Display for TreeElement { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Partial(s) => write!(f, "{s}")?, + Self::Number(n) => write!(f, "{n}")?, + Self::Add { l, r } => write!(f, "({l}+{r})")?, + Self::Div { l, r } => write!(f, "({l}÷{r})")?, + Self::Mul { l, r } => write!(f, "({l}×{r})")?, + Self::Sub { l, r } => write!(f, "({l}-{r})")?, + Self::Neg { r } => write!(f, "(-{r})")?, + } + Ok(()) + } +} + +impl Debug for TreeElement { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self, f) + } +} + +#[allow(dead_code)] +impl TreeElement { + pub fn left(&self) -> Option<&TreeElement> { + match self { + Self::Add { l, .. } + | Self::Sub { l, .. } + | Self::Mul { l, .. } + | Self::Div { l, .. } => Some(&**l), + _ => None, + } + } + + pub fn right(&self) -> Option<&TreeElement> { + match self { + Self::Add { r, .. } + | Self::Neg { r, .. } + | Self::Sub { r, .. } + | Self::Mul { r, .. } + | Self::Div { r, .. } => Some(&**r), + _ => None, + } + } + + pub fn left_mut(&mut self) -> Option<&mut TreeElement> { + match self { + Self::Add { l, .. } + | Self::Sub { l, .. } + | Self::Mul { l, .. } + | Self::Div { l, .. } => Some(&mut **l), + _ => None, + } + } + + pub fn right_mut(&mut self) -> Option<&mut TreeElement> { + match self { + Self::Add { r, .. } + | Self::Neg { r, .. } + | Self::Sub { r, .. } + | Self::Mul { r, .. } + | Self::Div { r, .. } => Some(&mut **r), + _ => None, + } + } + + pub fn evaluate(&self) -> Option { + match self { + Self::Number(x) => Some(*x), + // Try to parse strings of a partial + Self::Partial(s) => s.parse().ok(), + Self::Add { l, r } => { + let l = l.evaluate(); + let r = r.evaluate(); + if let (Some(l), Some(r)) = (l, r) { + Some(l + r) + } else { + None + } + } + Self::Mul { l, r } => { + let l = l.evaluate(); + let r = r.evaluate(); + if let (Some(l), Some(r)) = (l, r) { + Some(l * r) + } else { + None + } + } + Self::Div { l, r } => { + let l = l.evaluate(); + let r = r.evaluate(); + + if r == Some(0.0) { + None + } else if let (Some(l), Some(r)) = (l, r) { + Some(l / r) + } else { + None + } + } + Self::Sub { l, r } => { + let l = l.evaluate(); + let r = r.evaluate(); + if let (Some(l), Some(r)) = (l, r) { + Some(l - r) + } else { + None + } + } + Self::Neg { r } => { + let r = r.evaluate(); + r.map(|r| -r) + } + } + } +} diff --git a/rust/minimax/src/lib.rs b/rust/minimax/src/lib.rs new file mode 100644 index 0000000..ce59f43 --- /dev/null +++ b/rust/minimax/src/lib.rs @@ -0,0 +1,3 @@ +pub mod agents; +pub mod board; +pub mod util; diff --git a/rust/minimax/src/util.rs b/rust/minimax/src/util.rs new file mode 100644 index 0000000..2769898 --- /dev/null +++ b/rust/minimax/src/util.rs @@ -0,0 +1,111 @@ +use std::{ + fmt::{Debug, Display}, + num::NonZeroU8, +}; + +use rand::Rng; + +#[derive(PartialEq, Eq, Clone, Copy, Hash)] + +pub enum Symb { + Number(NonZeroU8), + Zero, + Plus, + Minus, + Times, + Div, +} + +impl Display for Symb { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Number(x) => write!(f, "{x}")?, + Self::Zero => write!(f, "0")?, + Self::Plus => write!(f, "+")?, + Self::Minus => write!(f, "-")?, + Self::Div => write!(f, "÷")?, + Self::Times => write!(f, "×")?, + } + Ok(()) + } +} +impl Debug for Symb { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Display::fmt(self, f) + } +} + +impl Symb { + /// Is this symbol a plain binary operator? + pub fn is_op(&self) -> bool { + matches!(self, Symb::Div | Symb::Plus | Symb::Times | Symb::Minus) + } + + pub fn is_minus(&self) -> bool { + self == &Self::Minus + } + + pub fn new_random(rng: &mut R) -> Self { + match rng.gen_range(0..=13) { + 0 => Symb::Zero, + n @ 1..=9 => Symb::Number(NonZeroU8::new(n).unwrap()), + 10 => Symb::Div, + 11 => Symb::Minus, + 12 => Symb::Plus, + 13 => Symb::Times, + _ => unreachable!(), + } + } + + pub const fn get_char(&self) -> Option { + match self { + Self::Plus => Some('+'), + Self::Minus => Some('-'), + Self::Times => Some('×'), + Self::Div => Some('÷'), + Self::Zero => Some('0'), + Self::Number(x) => match x.get() { + 1 => Some('1'), + 2 => Some('2'), + 3 => Some('3'), + 4 => Some('4'), + 5 => Some('5'), + 6 => Some('6'), + 7 => Some('7'), + 8 => Some('8'), + 9 => Some('9'), + _ => None, + }, + } + } + + pub fn from_str(s: &str) -> Option { + if s.chars().count() != 1 { + return None; + } + + Self::from_char(s.chars().next()?) + } + + pub const fn from_char(c: char) -> Option { + match c { + '1' => Some(Self::Number(unsafe { NonZeroU8::new_unchecked(1) })), + '2' => Some(Self::Number(unsafe { NonZeroU8::new_unchecked(2) })), + '3' => Some(Self::Number(unsafe { NonZeroU8::new_unchecked(3) })), + '4' => Some(Self::Number(unsafe { NonZeroU8::new_unchecked(4) })), + '5' => Some(Self::Number(unsafe { NonZeroU8::new_unchecked(5) })), + '6' => Some(Self::Number(unsafe { NonZeroU8::new_unchecked(6) })), + '7' => Some(Self::Number(unsafe { NonZeroU8::new_unchecked(7) })), + '8' => Some(Self::Number(unsafe { NonZeroU8::new_unchecked(8) })), + '9' => Some(Self::Number(unsafe { NonZeroU8::new_unchecked(9) })), + '0' => Some(Self::Zero), + '+' => Some(Self::Plus), + '-' => Some(Self::Minus), + '*' => Some(Self::Times), + '/' => Some(Self::Div), + '×' => Some(Self::Times), + '÷' => Some(Self::Div), + _ => None, + } + } +}