From ffae0342f18cc5e4f2827510dbcfba94f117d0c4 Mon Sep 17 00:00:00 2001 From: Mark Date: Mon, 4 Mar 2024 18:46:30 -0800 Subject: [PATCH] Added basic minmax calculations --- Cargo.lock | 46 ++++ Cargo.toml | 1 + src/agents/minmaxtree.rs | 441 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 488 insertions(+) create mode 100644 src/agents/minmaxtree.rs diff --git a/Cargo.lock b/Cargo.lock index a1654cc..6da9829 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,31 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" + [[package]] name = "either" version = "1.10.0" @@ -82,6 +107,7 @@ dependencies = [ "anyhow", "itertools", "rand", + "rayon", "termion", ] @@ -121,6 +147,26 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rayon" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4963ed1bc86e4f3ee217022bd855b297cef07fb9eac5dfa1f788b220b49b3bd" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.4.1" diff --git a/Cargo.toml b/Cargo.toml index ff1928d..b4fd351 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,4 +7,5 @@ edition = "2021" anyhow = "1.0.80" itertools = "0.12.1" rand = "0.8.5" +rayon = "1.9.0" termion = "3.0.0" diff --git a/src/agents/minmaxtree.rs b/src/agents/minmaxtree.rs new file mode 100644 index 0000000..aac1a3f --- /dev/null +++ b/src/agents/minmaxtree.rs @@ -0,0 +1,441 @@ +use std::{ + fmt::{Debug, Display}, + iter, + num::NonZeroU8, + thread, +}; + +use anyhow::Result; +use itertools::Itertools; +use rayon::iter::{ParallelBridge, ParallelIterator}; + +use super::{MaximizerAgent, MinimizerAgent, RandomAgent}; +use crate::{ + board::{Board, PlayerAction, TreeElement}, + util::Symb, +}; + +pub struct MinMaxTree {} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TreeDir { + Right, + Left, + This, +} + +#[derive(Clone, Copy)] +struct TreeCoords { + len: usize, + coords: [TreeDir; 4], + inversion: [bool; 4], +} + +impl Display for TreeCoords { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.get_inversion() { + write!(f, "-")? + } else { + write!(f, "+")? + } + + for c in self.coords { + match c { + TreeDir::Left => write!(f, "L")?, + TreeDir::Right => write!(f, "R")?, + TreeDir::This => break, + } + } + + Ok(()) + } +} + +impl Debug for TreeCoords { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Display::fmt(self, f) + } +} + +#[allow(dead_code)] +impl TreeCoords { + pub fn new() -> Self { + Self { + len: 0, + coords: [TreeDir::This; 4], + inversion: [false; 4], + } + } + + pub fn push(&mut self, dir: TreeDir, invert: bool) { + if self.len == 4 || dir == TreeDir::This { + return; + } + + self.coords[self.len] = dir; + self.inversion[self.len] = invert; + self.len += 1; + } + + pub fn pop(&mut self) -> Option<(TreeDir, bool)> { + if self.len == 0 { + return None; + } + + self.len -= 1; + let dir = self.coords[self.len]; + let inv = self.inversion[self.len]; + self.coords[self.len] = TreeDir::This; + self.inversion[self.len] = false; + Some((dir, inv)) + } + + pub fn get_inversion(&self) -> bool { + if self.len == 0 { + false + } else { + self.inversion[self.len - 1] + } + } + + pub fn get_from<'a>(&self, mut tree: &'a TreeElement) -> Option<&'a TreeElement> { + for i in 0..self.len { + match &self.coords[i] { + TreeDir::Left => { + if let Some(t) = tree.left() { + tree = t + } else { + return None; + } + } + + TreeDir::Right => { + if let Some(t) = tree.right() { + tree = t + } else { + return None; + } + } + + TreeDir::This => return Some(tree), + } + } + + Some(tree) + } + + pub fn get_from_mut<'a>(&self, mut tree: &'a mut TreeElement) -> Option<&'a mut TreeElement> { + for i in 0..self.len { + match &self.coords[i] { + TreeDir::Left => { + if let Some(t) = tree.left_mut() { + tree = t + } else { + return None; + } + } + + TreeDir::Right => { + if let Some(t) = tree.right_mut() { + tree = t + } else { + return None; + } + } + + TreeDir::This => return Some(tree), + } + } + + Some(tree) + } +} + +/// Count the number of free spaces in partials we want to minimize +fn count_min_slots(tree: &TreeElement, partials: &[TreeCoords]) -> usize { + partials + .iter() + .filter(|x| x.get_inversion()) + .map(|x| match x.get_from(tree) { + Some(TreeElement::Partial(s)) => s.chars().filter(|x| *x == '_').count(), + _ => unreachable!(), + }) + .sum() +} + +/// Count the number of free spaces in partials we want to maximize +fn count_max_slots(tree: &TreeElement, partials: &[TreeCoords]) -> usize { + partials + .iter() + .filter(|x| !x.get_inversion()) + .map(|x| match x.get_from(tree) { + Some(TreeElement::Partial(s)) => s.chars().filter(|x| *x == '_').count(), + _ => unreachable!(), + }) + .sum() +} + +/// Find the coordinates of all partials in the given tree +fn find_partials(tree: &TreeElement) -> Vec { + let mut partials = Vec::new(); + let mut current_coords = TreeCoords::new(); + + loop { + let t = current_coords.get_from(tree).unwrap(); + match t { + TreeElement::Number(_) | TreeElement::Partial(_) => { + if let TreeElement::Partial(_) = t { + partials.push(current_coords); + } + + loop { + match current_coords.pop() { + Some((TreeDir::Left, _)) => { + current_coords.push( + TreeDir::Right, + match current_coords.get_from(tree) { + Some(TreeElement::Add { .. }) => current_coords.get_inversion(), + Some(TreeElement::Mul { .. }) => current_coords.get_inversion(), + Some(TreeElement::Sub { .. }) => { + !current_coords.get_inversion() + } + Some(TreeElement::Div { .. }) => { + !current_coords.get_inversion() + } + _ => unreachable!(), + }, + ); + break; + } + Some((TreeDir::Right, _)) => {} + Some((TreeDir::This, _)) => unreachable!(), + None => return partials, + } + } + } + TreeElement::Div { .. } + | TreeElement::Mul { .. } + | TreeElement::Sub { .. } + | TreeElement::Add { .. } => current_coords.push(TreeDir::Left, current_coords.get_inversion()), + TreeElement::Neg { .. } => { + current_coords.push(TreeDir::Right, !current_coords.get_inversion()) + } + } + } +} + +fn fill_maxs( + tree: &TreeElement, + partials: &[TreeCoords], + mut numbers: impl Iterator, +) -> TreeElement { + let mut tmp_tree = tree.clone(); + for p in partials.iter().filter(|x| !x.get_inversion()) { + let x = p.get_from_mut(&mut tmp_tree).unwrap(); + + let x_str = match x { + TreeElement::Partial(s) => s, + _ => unreachable!(), + }; + let mut new_str = String::new(); + for c in x_str.chars() { + if c == '_' { + new_str.push_str(&format!("{}", numbers.next().unwrap())) + } else { + new_str.push(c); + } + } + *x = TreeElement::Number(new_str.parse().unwrap()) + } + + tmp_tree +} + +fn fill_mins( + tree: &TreeElement, + partials: &[TreeCoords], + mut numbers: impl Iterator, +) -> TreeElement { + let mut tmp_tree = tree.clone(); + for p in partials.iter().filter(|x| x.get_inversion()) { + let x = p.get_from_mut(&mut tmp_tree).unwrap(); + + let x_str = match x { + TreeElement::Partial(s) => s, + _ => unreachable!(), + }; + let mut new_str = String::new(); + for c in x_str.chars() { + if c == '_' { + new_str.push_str(&format!("{}", numbers.next().unwrap())) + } else { + new_str.push(c); + } + } + *x = TreeElement::Number(new_str.parse().unwrap()) + } + + tmp_tree +} + +fn find_best_maxs(tree: &TreeElement, partials: &[TreeCoords], maxs: &[Symb]) -> Vec { + // Fill maximizer slots in arbitrary order + let min_tree_base = fill_mins( + tree, + partials, + iter::repeat(Symb::Number(NonZeroU8::new(5).unwrap())), + ); + + let trees: Vec<(f32, Vec<&Symb>)> = maxs + .iter() + .permutations(maxs.len()) + .unique() + .par_bridge() + .filter_map(|l| { + let mut i = l.iter(); + let mut tmp_tree = min_tree_base.clone(); + for p in partials.iter().filter(|x| !x.get_inversion()) { + let x = p.get_from_mut(&mut tmp_tree).unwrap(); + + let x_str = match x { + TreeElement::Partial(s) => s, + _ => unreachable!(), + }; + let mut new_str = String::new(); + for c in x_str.chars() { + if c == '_' { + new_str.push_str(&format!("{}", i.next().unwrap())) + } else { + new_str.push(c); + } + } + *x = TreeElement::Number(new_str.parse().unwrap()) + } + + tmp_tree.evaluate().map(|x| (x, l)) + }) + .collect(); + + let mut max_list: Option> = None; + let mut best_value: Option = None; + + for (x, list) in trees { + if let Some(m) = best_value { + if m < x { + best_value = Some(x); + max_list = Some(list); + } + } else { + best_value = Some(x); + max_list = Some(list); + } + } + + max_list.unwrap().into_iter().cloned().collect() +} + +fn find_best_mins(tree: &TreeElement, partials: &[TreeCoords], mins: &[Symb]) -> Vec { + // Fill maximizer slots in arbitrary order + let min_tree_base = fill_maxs( + tree, + partials, + iter::repeat(Symb::Number(NonZeroU8::new(5).unwrap())), + ); + + let trees: Vec<(f32, Vec<&Symb>)> = mins + .iter() + .permutations(mins.len()) + .unique() + .par_bridge() + .filter_map(|l| { + let mut i = l.iter(); + let mut tmp_tree = min_tree_base.clone(); + for p in partials.iter().filter(|x| x.get_inversion()) { + let x = p.get_from_mut(&mut tmp_tree).unwrap(); + + let x_str = match x { + TreeElement::Partial(s) => s, + _ => unreachable!(), + }; + let mut new_str = String::new(); + for c in x_str.chars() { + if c == '_' { + new_str.push_str(&format!("{}", i.next().unwrap())) + } else { + new_str.push(c); + } + } + *x = TreeElement::Number(new_str.parse().unwrap()) + } + + tmp_tree.evaluate().map(|x| (x, l)) + }) + .collect(); + + let mut min_list: Option> = None; + let mut best_value: Option = None; + + for (x, list) in trees { + if let Some(m) = best_value { + if m < x { + best_value = Some(x); + min_list = Some(list); + } + } else { + best_value = Some(x); + min_list = Some(list); + } + } + + min_list.unwrap().into_iter().cloned().collect() +} + +impl MinMaxTree {} + +impl MinimizerAgent for MinMaxTree { + fn step_min(&mut self, board: &Board) -> Result { + let tree = board.to_tree(); + let partials = find_partials(&tree); + + let max_slots = count_max_slots(&tree, &partials); + let min_slots = count_min_slots(&tree, &partials); + + let available_numbers = (0..=9) + .map(|x| match x { + 0 => Symb::Zero, + x => Symb::Number(NonZeroU8::new(x).unwrap()), + }) + .filter(|x| !board.contains(*x)) + .collect::>(); + + if available_numbers.len() < max_slots { + return RandomAgent {}.step_max(board); + } + + // Assume these won't ever overlap + // (that is, min_slots + max_slots <= available_numbers.len) + let mins: Vec = available_numbers[0..min_slots].to_vec(); + let maxs: Vec = available_numbers[available_numbers.len() - max_slots..] + .iter() + .copied() + .rev() + .collect(); + + let t = tree.clone(); + let p = partials.clone(); + let ha = thread::spawn(move || find_best_mins(&t, &p, &mins[..])); + + let t = tree.clone(); + let p = partials.clone(); + let hb = thread::spawn(move || find_best_maxs(&t, &p, &maxs[..])); + + let best_min_list = ha.join().unwrap(); + let best_max_list = hb.join().unwrap(); + + let t = fill_mins(&tree, &partials, best_min_list.into_iter()); + let t = fill_maxs(&t, &partials, best_max_list.into_iter()); + + println!("{:?}", t); + RandomAgent {}.step_max(board) + } +}