From 993e813b7e5ab233e98a004ec54ae78b10812e76 Mon Sep 17 00:00:00 2001
From: rm-dr <96270320+rm-dr@users.noreply.github.com>
Date: Mon, 15 Dec 2025 21:59:14 -0800
Subject: [PATCH] Refactor

---
 .editorconfig                               |  13 +
 Cargo.lock                                  | 801 +++++++++++++-------
 Cargo.toml                                  |  10 +-
 README.md                                   |  25 +
 crates/llmfs/Cargo.toml                     |   3 +
 crates/llmfs/src/cli.rs                     |  13 +
 crates/llmfs/src/command/download.rs        |  11 +-
 crates/llmfs/src/command/mod.rs             |   4 +-
 crates/llmfs/src/command/sample_data.rs     | 735 ------------------
 crates/llmfs/src/command/train_model.rs     | 312 ++++++++
 crates/llmfs/src/command/train_tokenizer.rs |  20 +-
 crates/llmfs/src/data_reader.rs             |  14 +-
 crates/llmfs/src/logging.rs                 |  26 +-
 crates/llmfs/src/main.rs                    |  68 ++
 crates/llmfs/src/parts/attention.rs         | 228 ++++++
 crates/llmfs/src/parts/mod.rs               |   5 +
 crates/llmfs/src/parts/model.rs             | 194 +++++
 crates/llmfs/src/train_test_iterator.rs     | 164 ++++
 crates/tokenizer/src/tokenizer.rs           |   3 +-
 19 files changed, 1583 insertions(+), 1066 deletions(-)
 create mode 100644 .editorconfig
 create mode 100644 README.md
 delete mode 100644 crates/llmfs/src/command/sample_data.rs
 create mode 100644 crates/llmfs/src/command/train_model.rs
 create mode 100644 crates/llmfs/src/parts/attention.rs
 create mode 100644 crates/llmfs/src/parts/mod.rs
 create mode 100644 crates/llmfs/src/parts/model.rs
 create mode 100644 crates/llmfs/src/train_test_iterator.rs

diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 0000000..78662d1
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,13 @@
+root = true
+
+[*]
+indent_style = tab
+indent_size = 4
+end_of_line = lf
+charset = utf-8
+trim_trailing_whitespace = true
+insert_final_newline = true
+
+[*.md]
+indent_size = 2
+indent_style = space
diff --git a/Cargo.lock b/Cargo.lock
index 0767f61..c6bf878 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2,6 +2,15 @@
 # It is not intended for manual editing.
version = 4 +[[package]] +name = "addr2line" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b5d307320b3181d6d7954e663bd7c774a838b8220fe0593c86d9fb09f498b4b" +dependencies = [ + "gimli", +] + [[package]] name = "adler2" version = "2.0.1" @@ -273,6 +282,21 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "backtrace" +version = "0.3.76" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb531853791a215d7c62a30daf0dde835f381ab5de4589cfe7c649d2cbe92bd6" +dependencies = [ + "addr2line", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", + "windows-link 0.2.1", +] + [[package]] name = "base64" version = "0.22.1" @@ -383,9 +407,8 @@ checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" [[package]] name = "burn" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0291ea5c68786545e239a02f63331cfe39da7485164ae05197d5be6f148d0557" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ "burn-autodiff", "burn-candle", @@ -404,60 +427,64 @@ dependencies = [ [[package]] name = "burn-autodiff" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "917423a74bf4d39f17a6799089869648e3d2b6ac89d93901aab4aeb9a7f82138" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "burn-common", - "burn-tensor", + "burn-backend", + "burn-std", "derive-new", - "hashbrown 0.15.5", + "hashbrown 0.16.1", "log", "num-traits", + "parking_lot", "portable-atomic", "spin", ] [[package]] -name = "burn-candle" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2891811d41ae30b5f1f660e7615b757b2cb4128af5e311b213656de3875e4acb" +name = "burn-backend" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "burn-common", - "burn-tensor", - "candle-core", + "burn-std", + "bytemuck", + "cubecl", "derive-new", - "half", + "hashbrown 0.16.1", + "num-traits", + "rand", + "rand_distr", + "serde", + "thiserror 2.0.17", ] [[package]] -name = "burn-common" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5eb445304e4f91f8633d23c9a5258cd93639d13ce2ee47d4821fd519b683bf02" +name = "burn-candle" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "cubecl-common", - "rayon", - "serde", + "burn-backend", + "burn-std", + "candle-core", + "derive-new", ] [[package]] name = "burn-core" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20c93e754864080a8c27b9a47e3b6f7d79013cf82c9ce00ed57c9ba51a3e34c5" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ "ahash", "bincode", - "burn-common", + "burn-dataset", "burn-derive", + "burn-std", "burn-tensor", "data-encoding", "derive-new", "flate2", "half", - "hashbrown 0.15.5", + "hashbrown 0.16.1", "log", "num-traits", "portable-atomic", @@ -473,84 +500,82 @@ dependencies = [ [[package]] name = "burn-cpu" 
-version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4807930d243f1aa9dde99db372af56ac532cc6635fd3187156aee375fbadc07" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ + "burn-backend", "burn-cubecl", "burn-fusion", - "burn-tensor", - "bytemuck", "cubecl", - "derive-new", - "half", - "log", ] [[package]] name = "burn-cubecl" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dd16308b7b0291c77f2d7acf428bc8254ec3db88a430a26cf3d3b0b63ae2d46" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "burn-common", + "burn-backend", "burn-cubecl-fusion", "burn-fusion", "burn-ir", - "burn-tensor", - "bytemuck", + "burn-std", "cubecl", - "cubecl-quant", + "cubek", "derive-new", "futures-lite", - "half", - "hashbrown 0.15.5", "log", - "num-traits", - "rand", "serde", - "spin", "text_placeholder", ] [[package]] name = "burn-cubecl-fusion" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc21cf88201dfbf242cadb638a0cc924010727fc37d6a719f7e10548b339c63a" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "burn-common", + "burn-backend", "burn-fusion", "burn-ir", - "burn-tensor", + "burn-std", "cubecl", - "cubecl-quant", + "cubek", "derive-new", - "half", "serde", ] [[package]] name = "burn-cuda" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e104dcf07eac70c7b5864b51d792df3360b11b00febb60543b4283bb414bb61" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ + "burn-backend", "burn-cubecl", "burn-fusion", - "burn-tensor", - "bytemuck", "cubecl", +] + +[[package]] +name = "burn-dataset" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" +dependencies = [ + "csv", "derive-new", - "half", - "log", + "dirs", + "rand", + "rmp-serde", + "sanitize-filename", + "serde", + "serde_json", + "strum", + "tempfile", + "thiserror 2.0.17", ] [[package]] name = "burn-derive" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bcf49261de086b8206de6c8962d2adf23feb476119a18e384f5b2c9af07c0cf" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ "derive-new", "proc-macro2", @@ -560,16 +585,13 @@ dependencies = [ [[package]] name = "burn-fusion" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "662bf2679c04be34a0c3f1b11f77f6ff49456af1620d1eca311bc2562bbb56c9" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "burn-common", + "burn-backend", "burn-ir", - "burn-tensor", "derive-new", - "half", - "hashbrown 0.15.5", + "hashbrown 0.16.1", "log", "serde", "spin", @@ -577,45 +599,44 @@ dependencies = [ [[package]] name = "burn-ir" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9161239d5691c4ab6f470f2c65aaec5c0a7c1f0b0da390700bcd59f5a77d1d7b" +version = "0.20.0-pre.5" +source = 
"git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "burn-tensor", - "hashbrown 0.15.5", + "burn-backend", + "hashbrown 0.16.1", "portable-atomic-util", "serde", ] [[package]] name = "burn-ndarray" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b78bcf4a3508043342f918e796dc79108b5f3252398403eb73952847e7683374" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ "atomic_float", "burn-autodiff", - "burn-common", + "burn-backend", "burn-ir", - "burn-tensor", + "burn-std", + "bytemuck", "const-random", - "derive-new", + "itertools 0.14.0", "libm", "macerator", "matrixmultiply", - "ndarray", + "ndarray 0.17.1", "num-traits", "paste", "portable-atomic-util", "rand", - "spin", + "rayon", + "seq-macro", ] [[package]] name = "burn-nn" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc7829c87c4dd6c7929b50fd981e7e8d1b77414323da30ce2067a3e8b7ea422b" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ "burn-core", "num-traits", @@ -623,13 +644,12 @@ dependencies = [ [[package]] name = "burn-optim" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31758c02e50247f12457fca1905ed8684ac1b1c5292e10cbbfffb9fa0048d4bd" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ "burn-core", "derive-new", - "hashbrown 0.15.5", + "hashbrown 0.16.1", "log", "num-traits", "serde", @@ -637,81 +657,99 @@ dependencies = [ [[package]] name = "burn-rocm" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e1ceb87b6e7349b42d7995477c9a69d0e6c458c64eafa10af3b8b9070f260aa" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ + "burn-backend", "burn-cubecl", "burn-fusion", - "burn-tensor", - "bytemuck", "cubecl", - "derive-new", - "half", - "log", ] [[package]] name = "burn-router" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45f40403c500b5df380bee47aa0f23032350bdfde5402812d6fcec4d6ff6fbad" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "burn-common", + "burn-backend", "burn-ir", - "burn-tensor", - "hashbrown 0.15.5", + "burn-std", + "hashbrown 0.16.1", "log", "spin", ] +[[package]] +name = "burn-std" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" +dependencies = [ + "bytemuck", + "bytes", + "cubecl", + "cubecl-common", + "half", + "num-traits", + "serde", +] + [[package]] name = "burn-store" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a2a163486242fcb0c6e2cb89c5a803ab8588673652bb46ecd7af6378d06152f" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ "burn-core", "burn-nn", "burn-tensor", "byteorder", + "bytes", "half", - "hashbrown 0.15.5", + "hashbrown 0.16.1", "memmap2", "regex", - "safetensors 0.6.2", + "safetensors 0.7.0", + "textdistance", ] [[package]] name = "burn-tensor" -version = 
"0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df8861f7c21d3b07a2b19d028f6eb8903990949708b2ec825559b5200786877c" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "burn-common", - "bytemuck", + "burn-backend", + "burn-std", "colored", - "cubecl", - "cubecl-quant", "derive-new", - "half", - "hashbrown 0.15.5", "num-traits", - "rand", - "rand_distr", "serde", - "serde_bytes", +] + +[[package]] +name = "burn-train" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" +dependencies = [ + "async-channel", + "burn-core", + "burn-ndarray", + "burn-optim", + "derive-new", + "log", + "rstest", + "serde", + "tracing-appender", + "tracing-core", + "tracing-subscriber", ] [[package]] name = "burn-wgpu" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c17aeaa2eadaa4831a64672b99f62ffcdf4874fe4757080633d8a6c4452e2b38" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ + "burn-backend", "burn-cubecl", "burn-fusion", - "burn-tensor", "cubecl", ] @@ -746,6 +784,9 @@ name = "bytes" version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" +dependencies = [ + "portable-atomic", +] [[package]] name = "candle-core" @@ -1057,6 +1098,15 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -1099,19 +1149,36 @@ dependencies = [ ] [[package]] -name = "cubecl" -version = "0.8.1" +name = "csv" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8b7c74ecaca9356c9ae79d0ebf1db04f02bd98be09eea61f51d73373dffe758" +checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938" dependencies = [ - "cubecl-convolution", + "csv-core", + "itoa", + "ryu", + "serde_core", +] + +[[package]] +name = "csv-core" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782" +dependencies = [ + "memchr", +] + +[[package]] +name = "cubecl" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" +dependencies = [ + "cfg_aliases", "cubecl-core", "cubecl-cpu", "cubecl-cuda", "cubecl-hip", - "cubecl-matmul", - "cubecl-random", - "cubecl-reduce", "cubecl-runtime", "cubecl-std", "cubecl-wgpu", @@ -1120,11 +1187,12 @@ dependencies = [ [[package]] name = "cubecl-common" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4556981155bffc057a8effcd4549b52b51df3e9edec43af6ccae2dd03fc8fbff" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ + "backtrace", "bytemuck", + 
"bytes", "cfg-if", "cfg_aliases", "derive-new", @@ -1152,30 +1220,10 @@ dependencies = [ "web-time", ] -[[package]] -name = "cubecl-convolution" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27c624ec400b7203673bf2db86d7ff30d1384839d497d2dd029c19b1b7371e0d" -dependencies = [ - "bytemuck", - "cubecl-common", - "cubecl-core", - "cubecl-matmul", - "cubecl-random", - "cubecl-reduce", - "cubecl-runtime", - "cubecl-std", - "half", - "pretty_assertions", - "serde", -] - [[package]] name = "cubecl-core" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ffc10af538ee74535cda260e581f5a177c243803dd30b698934a515f0114b55" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "bitflags 2.10.0", "bytemuck", @@ -1198,9 +1246,8 @@ dependencies = [ [[package]] name = "cubecl-cpp" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d630e4d10cdd3af268ac753914ca79b48f01d1e36c5b5039970a817acc925fea" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "bytemuck", "cubecl-common", @@ -1215,17 +1262,13 @@ dependencies = [ [[package]] name = "cubecl-cpu" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac1693555277d74152afb61a23e30d1f17d72cebd317a648faf50a8e69380f08" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "bytemuck", "cubecl-common", - "cubecl-convolution", "cubecl-core", - "cubecl-matmul", "cubecl-opt", - "cubecl-reduce", "cubecl-runtime", "cubecl-std", "derive-new", @@ -1239,9 +1282,8 @@ dependencies = [ [[package]] name = "cubecl-cuda" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67215fcd552a9e8bc68494a71cf2979f2e2bbcbda60f0695f56f86705b89ed5f" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "bytemuck", "cubecl-common", @@ -1257,16 +1299,14 @@ dependencies = [ [[package]] name = "cubecl-hip" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5e2e6a257f702fb2eb6f24e640e228a94695e4a4c73a4c549578cbb02ad4ec5" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "bytemuck", "cubecl-common", "cubecl-core", "cubecl-cpp", "cubecl-hip-sys", - "cubecl-quant", "cubecl-runtime", "derive-new", "half", @@ -1287,14 +1327,14 @@ dependencies = [ [[package]] name = "cubecl-ir" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf5d3aa7857e6aee1622aef128d6ad8d9289ed57362b4e65d10cc182aafc585f" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "cubecl-common", "cubecl-macros-internal", "derive-new", "derive_more", + "enumset", "float-ord", "fnv", "half", @@ -1307,9 +1347,8 @@ dependencies = [ [[package]] name = "cubecl-macros" 
-version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5200fb619be424749901e3c6e8e66ae71146c8f83636a74f171bd980cba379d7" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "cubecl-common", "darling 0.21.3", @@ -1323,9 +1362,8 @@ dependencies = [ [[package]] name = "cubecl-macros-internal" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a1b673f303396fba18df83368aa4eced474584f1bca34852dccc42bd4ff050c" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "darling 0.21.3", "proc-macro2", @@ -1333,29 +1371,10 @@ dependencies = [ "syn", ] -[[package]] -name = "cubecl-matmul" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1cf0a00609a249d5357c27cafea477f35218579db2ab00582d8d5800be4a5a3" -dependencies = [ - "bytemuck", - "cubecl-common", - "cubecl-core", - "cubecl-random", - "cubecl-reduce", - "cubecl-runtime", - "cubecl-std", - "half", - "pretty_assertions", - "serde", -] - [[package]] name = "cubecl-opt" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "870ca4b52f9eebd358c9b360b89cdc9f82bde05682db63f0e90c666b3c85a04d" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "cubecl-common", "cubecl-core", @@ -1369,57 +1388,10 @@ dependencies = [ "type-map", ] -[[package]] -name = "cubecl-quant" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9be3e1202c219078d85dbad7f30d1195fe4f9d42cbfad2c94ab0ea1a6d9f01f6" -dependencies = [ - "cubecl-common", - "cubecl-core", - "cubecl-runtime", - "cubecl-std", - "half", - "serde", -] - -[[package]] -name = "cubecl-random" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a293a05caa68663675823bab66205bca094a21a2c0f6686ad9f20b392516179" -dependencies = [ - "cubecl-common", - "cubecl-core", - "cubecl-runtime", - "cubecl-std", - "half", - "num-traits", - "rand", - "serde", -] - -[[package]] -name = "cubecl-reduce" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53306ace81f6262f7ae794370f47e6b5019842b27e8800240e5b039386b3ac3a" -dependencies = [ - "cubecl-core", - "cubecl-runtime", - "cubecl-std", - "half", - "num-traits", - "pretty_assertions", - "rand", - "serde", -] - [[package]] name = "cubecl-runtime" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91b823bb5899a6fa8809bf7aa36f93f72ced6de58ab9d6edea2c730b235eeda3" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "async-channel", "bytemuck", @@ -1430,7 +1402,7 @@ dependencies = [ "derive-new", "dirs", "enumset", - "foldhash", + "foldhash 0.1.5", "hashbrown 0.15.5", "log", "md5", @@ -1445,15 +1417,15 @@ dependencies = [ [[package]] name = "cubecl-std" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24536998f9fff84f9a1dd2a90f981d5aa4d15eb35cddec5021c4fcf977d2e75e" +version = 
"0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "cubecl-common", "cubecl-core", "cubecl-runtime", - "foldhash", + "foldhash 0.1.5", "half", + "num-traits", "paste", "serde", "spin", @@ -1462,9 +1434,8 @@ dependencies = [ [[package]] name = "cubecl-wgpu" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d59a7d737259a784247595e2f0cc5a97d3e50f45cdaefbd4cc7d7fd2126f7a58" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "async-channel", "bytemuck", @@ -1482,11 +1453,103 @@ dependencies = [ "wgpu", ] +[[package]] +name = "cubek" +version = "0.0.1" +source = "git+https://github.com/tracel-ai/cubek?rev=4442ccf777030c3ae0dcb3ebadcc85066d96703f#4442ccf777030c3ae0dcb3ebadcc85066d96703f" +dependencies = [ + "cubecl", + "cubek-attention", + "cubek-convolution", + "cubek-matmul", + "cubek-quant", + "cubek-random", + "cubek-reduce", +] + +[[package]] +name = "cubek-attention" +version = "0.0.1" +source = "git+https://github.com/tracel-ai/cubek?rev=4442ccf777030c3ae0dcb3ebadcc85066d96703f#4442ccf777030c3ae0dcb3ebadcc85066d96703f" +dependencies = [ + "bytemuck", + "cubecl", + "cubecl-common", + "cubek-matmul", + "cubek-random", + "half", + "serde", +] + +[[package]] +name = "cubek-convolution" +version = "0.0.1" +source = "git+https://github.com/tracel-ai/cubek?rev=4442ccf777030c3ae0dcb3ebadcc85066d96703f#4442ccf777030c3ae0dcb3ebadcc85066d96703f" +dependencies = [ + "bytemuck", + "cubecl", + "cubecl-common", + "cubek-matmul", + "derive-new", + "half", + "serde", +] + +[[package]] +name = "cubek-matmul" +version = "0.0.1" +source = "git+https://github.com/tracel-ai/cubek?rev=4442ccf777030c3ae0dcb3ebadcc85066d96703f#4442ccf777030c3ae0dcb3ebadcc85066d96703f" +dependencies = [ + "bytemuck", + "cubecl", + "cubecl-common", + "cubek-random", + "cubek-reduce", + "half", + "serde", +] + +[[package]] +name = "cubek-quant" +version = "0.0.1" +source = "git+https://github.com/tracel-ai/cubek?rev=4442ccf777030c3ae0dcb3ebadcc85066d96703f#4442ccf777030c3ae0dcb3ebadcc85066d96703f" +dependencies = [ + "cubecl", + "cubecl-common", + "half", + "serde", +] + +[[package]] +name = "cubek-random" +version = "0.0.1" +source = "git+https://github.com/tracel-ai/cubek?rev=4442ccf777030c3ae0dcb3ebadcc85066d96703f#4442ccf777030c3ae0dcb3ebadcc85066d96703f" +dependencies = [ + "cubecl", + "cubecl-common", + "half", + "num-traits", + "rand", + "serde", +] + +[[package]] +name = "cubek-reduce" +version = "0.0.1" +source = "git+https://github.com/tracel-ai/cubek?rev=4442ccf777030c3ae0dcb3ebadcc85066d96703f#4442ccf777030c3ae0dcb3ebadcc85066d96703f" +dependencies = [ + "cubecl", + "half", + "num-traits", + "serde", + "thiserror 2.0.17", +] + [[package]] name = "cudarc" -version = "0.17.8" +version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf99ab37ee7072d64d906aa2dada9a3422f1d975cdf8c8055a573bc84897ed8" +checksum = "3aa12038120eb13347a6ae2ffab1d34efe78150125108627fd85044dd4d6ff1e" dependencies = [ "libloading", ] @@ -1573,6 +1636,15 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" +[[package]] +name = "deranged" +version = "0.5.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "ececcb659e7ba858fb4f10388c250a7252eb0a27373f1a72b8748afdd248e587" +dependencies = [ + "powerfmt", +] + [[package]] name = "derive-new" version = "0.7.0" @@ -1624,12 +1696,6 @@ version = "1.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "abd57806937c9cc163efc8ea3910e00a62e2aeb0b8119f1793a978088f8f6b04" -[[package]] -name = "diff" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" - [[package]] name = "digest" version = "0.10.7" @@ -1809,6 +1875,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25b07a8dfbbbfc0064c0a6bdf9edcf966de6b1c33ce344bdeca3b41615452634" dependencies = [ "enumset_derive", + "serde", ] [[package]] @@ -1955,6 +2022,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "foreign-types" version = "0.3.2" @@ -2064,6 +2137,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.31" @@ -2355,6 +2434,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "gimli" +version = "0.32.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e629b9b98ef3dd8afe6ca2bd0f89306cec16d43d907889945bc5d6687f2f13c7" + [[package]] name = "gl_generator" version = "0.14.0" @@ -2496,7 +2581,7 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ "allocator-api2", "equivalent", - "foldhash", + "foldhash 0.1.5", "serde", ] @@ -2505,6 +2590,13 @@ name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash 0.2.0", + "serde", + "serde_core", +] [[package]] name = "heck" @@ -3044,15 +3136,18 @@ checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" name = "llmfs" version = "0.0.1" dependencies = [ + "ahash", "anstyle", "anyhow", "burn", + "burn-train", "clap", "futures-util", "indicatif", - "ndarray", + "ndarray 0.16.1", "parking_lot", "parquet", + "rand", "rayon", "reqwest", "serde", @@ -3149,7 +3244,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" dependencies = [ "autocfg", + "num_cpus", + "once_cell", "rawpointer", + "thread-tree", ] [[package]] @@ -3302,6 +3400,22 @@ dependencies = [ "serde", ] +[[package]] +name = "ndarray" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c7c9125e8f6f10c9da3aad044cc918cf8784fa34de857b1aa68038eb05a50a9" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + 
"portable-atomic-util", + "rawpointer", + "rayon", +] + [[package]] name = "ndk-sys" version = "0.6.0+11769913" @@ -3379,6 +3493,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-integer" version = "0.1.46" @@ -3480,6 +3600,15 @@ dependencies = [ "objc2-core-foundation", ] +[[package]] +name = "object" +version = "0.37.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff76201f031d8863c38aa7f905eca4f53abbfa15f609db4277d44cd8938f33fe" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -3689,6 +3818,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.21" @@ -3704,16 +3839,6 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8cf8e6a8aa66ce33f63993ffc4ea4271eb5b0530a9002db8455ea6050c77bfa" -[[package]] -name = "pretty_assertions" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d" -dependencies = [ - "diff", - "yansi", -] - [[package]] name = "prettyplease" version = "0.2.37" @@ -3994,6 +4119,12 @@ version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" +[[package]] +name = "relative-path" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" + [[package]] name = "renderdoc-sys" version = "1.1.0" @@ -4084,6 +4215,41 @@ dependencies = [ "serde", ] +[[package]] +name = "rstest" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5a3193c063baaa2a95a33f03035c8a72b83d97a54916055ba22d35ed3839d49" +dependencies = [ + "futures-timer", + "futures-util", + "rstest_macros", +] + +[[package]] +name = "rstest_macros" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c845311f0ff7951c5506121a9ad75aec44d083c31583b2ea5a30bcb0b0abba0" +dependencies = [ + "cfg-if", + "glob", + "proc-macro-crate", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn", + "unicode-ident", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" + [[package]] name = "rustc-hash" version = "1.1.0" @@ -4177,10 +4343,11 @@ dependencies = [ [[package]] name = "safetensors" -version = "0.6.2" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "172dd94c5a87b5c79f945c863da53b2ebc7ccef4eca24ac63cca66a41aab2178" +checksum = "675656c1eabb620b921efea4f9199f97fc86e36dd6ffd1fbbe48d0f59a4987f5" dependencies = [ + "hashbrown 0.16.1", "serde", "serde_json", ] @@ -4632,6 +4799,12 @@ dependencies = [ "serde_json", ] +[[package]] +name = "textdistance" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"aa672c55ab69f787dbc9126cc387dbe57fdd595f585e4524cf89018fa44ab819" + [[package]] name = "thiserror" version = "1.0.69" @@ -4672,6 +4845,15 @@ dependencies = [ "syn", ] +[[package]] +name = "thread-tree" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffbd370cb847953a25954d9f63e14824a36113f8c72eecf6eccef5dc4b45d630" +dependencies = [ + "crossbeam-channel", +] + [[package]] name = "thread_local" version = "1.1.9" @@ -4692,6 +4874,37 @@ dependencies = [ "ordered-float 2.10.1", ] +[[package]] +name = "time" +version = "0.3.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" + +[[package]] +name = "time-macros" +version = "0.2.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "tiny-keccak" version = "2.0.2" @@ -4989,6 +5202,18 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-appender" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "786d480bce6247ab75f005b14ae1624ad978d3029d9113f0a22fa1ac773faeaf" +dependencies = [ + "crossbeam-channel", + "thiserror 2.0.17", + "time", + "tracing-subscriber", +] + [[package]] name = "tracing-attributes" version = "0.1.31" @@ -5997,12 +6222,6 @@ version = "0.8.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3ae8337f8a065cfc972643663ea4279e04e7256de865aa66fe25cec5fb912d3f" -[[package]] -name = "yansi" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" - [[package]] name = "yoke" version = "0.7.5" diff --git a/Cargo.toml b/Cargo.toml index 82380cd..0c6d59a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,10 +75,12 @@ compact_str = "0.9.0" dary_heap = "0.3.8" fancy-regex = "0.16.2" indicatif = { version = "0.18.3", features = ["improved_unicode"] } +itertools = "0.14.0" futures-util = "0.3.31" ndarray = { version = "0.16.1", features = ["serde"] } parking_lot = "0.12.5" parquet = "56.2.0" +rand = "0.9.2" rayon = "1.11.0" reqwest = { version = "0.12.24", features = ["json", "stream"] } serde = "1.0.228" @@ -91,7 +93,11 @@ tracing-indicatif = "0.3.13" tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] } url = "2.5.7" + +burn-train = { git = "https://github.com/tracel-ai/burn.git", default-features = false } + [workspace.dependencies.burn] -version = "0.19.1" +#version = "0.19.1" +git = "https://github.com/tracel-ai/burn.git" default-features = false -features = ["std", "fusion", "ndarray", "webgpu", "cuda"] +features = ["std", "fusion", "ndarray", "webgpu", "cuda", "autodiff"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..1cd5330 --- /dev/null +++ b/README.md @@ -0,0 +1,25 @@ +# LLM from scratch + +## Resources +- [Build a Large Language Model](https://www.manning.com/books/build-a-large-language-model-from-scratch) +- [Writing an LLM from scratch, part 
28](https://www.gilesthomas.com/2025/12/llm-from-scratch-28-training-a-base-model-from-scratch)
+- [nanochat](https://github.com/karpathy/nanochat)
+
+## TODO:
+- chat cli, evaluate each epoch
+- better arch (read nanochat)
+- count tokens
+- download more data (code, full fineweb)
+- better train progress bar
+- Notes
+
+- TrainTestIterator
+  - total length
+  - deterministic shuffle
+  - prepare in parallel
+  - refactor new() into builder
+  - small texts (<|bos|>?)
+
+- Training
+  - multi-device training
+  - model parameters in file
diff --git a/crates/llmfs/Cargo.toml b/crates/llmfs/Cargo.toml
index 69bd819..850317d 100644
--- a/crates/llmfs/Cargo.toml
+++ b/crates/llmfs/Cargo.toml
@@ -10,15 +10,18 @@ workspace = true

 [dependencies]
 tokenizer = { workspace = true }
+ahash = { workspace = true }
 anstyle = { workspace = true }
 anyhow = { workspace = true }
 burn = { workspace = true }
+burn-train = { workspace = true }
 clap = { workspace = true }
 futures-util = { workspace = true }
 indicatif = { workspace = true }
 ndarray = { workspace = true }
 parking_lot = { workspace = true }
 parquet = { workspace = true }
+rand = { workspace = true }
 rayon = { workspace = true }
 reqwest = { workspace = true }
 serde = { workspace = true }
diff --git a/crates/llmfs/src/cli.rs b/crates/llmfs/src/cli.rs
index 01d17bc..cf6cda2 100644
--- a/crates/llmfs/src/cli.rs
+++ b/crates/llmfs/src/cli.rs
@@ -62,3 +62,16 @@ pub fn progress_bytes() -> ProgressStyle {
 		"⠉⠉", "⠈⠙", "⠀⠹", "⠀⢸", "⠀⣰", "⢀⣠", "⣀⣀", "⣄⡀", "⣆⠀", "⡇⠀", "⠏⠀", "⠋⠁", "⣏⣹",
 	]);
 }
+
+#[expect(clippy::unwrap_used)]
+pub fn progress_persec() -> ProgressStyle {
+	return ProgressStyle::default_bar()
+		.template(
+			" {bar:16.red/white.dim} {elapsed_precise:.dim} {pos}/{len} ({per_sec:>3}) {msg:.dim} ({eta})",
+		)
+		.unwrap()
+		.progress_chars("---")
+		.tick_strings(&[
+			"⠉⠉", "⠈⠙", "⠀⠹", "⠀⢸", "⠀⣰", "⢀⣠", "⣀⣀", "⣄⡀", "⣆⠀", "⡇⠀", "⠏⠀", "⠋⠁", "⣏⣹",
+		]);
+}
diff --git a/crates/llmfs/src/command/download.rs b/crates/llmfs/src/command/download.rs
index 9674cb4..72f1f10 100644
--- a/crates/llmfs/src/command/download.rs
+++ b/crates/llmfs/src/command/download.rs
@@ -21,9 +21,8 @@ const MAX_SHARD: usize = 1822;

 #[derive(Debug, Args, Clone)]
 pub struct DownloadArgs {
-	/// Training data dir
-	#[clap(default_value = "data")]
-	data_dir: PathBuf,
+	/// Training data directory (will be created)
+	data: PathBuf,

 	/// Number of shards to download (-1 for all)
 	#[arg(short = 'n', long, default_value = "-1")]
@@ -37,7 +36,7 @@ pub struct DownloadArgs {
 impl DownloadArgs {
 	pub fn run(self, mp: Option<MultiProgress>) -> Result<()> {
 		info!("Downloading files from {BASE_URL}");
-		fs::create_dir_all(&self.data_dir)?;
+		fs::create_dir_all(&self.data)?;

 		let num_shards_to_download = if self.num_files == -1 {
 			MAX_SHARD + 1
@@ -48,7 +47,7 @@ impl DownloadArgs {
 		let ids_to_download: Vec<usize> = (0..num_shards_to_download).collect();

 		info!("Downloading {} shards...", ids_to_download.len(),);
-		info!("Target directory: {}", self.data_dir.display());
+		info!("Target directory: {}", self.data.display());

 		let main_pb = mp.as_ref().map(|mp| {
 			let pb = mp.add(ProgressBar::new(ids_to_download.len() as u64));
@@ -70,7 +69,7 @@ impl DownloadArgs {
 		ids_to_download
 			.into_par_iter()
 			.for_each_with(tx, |tx, index| {
-				let target = self.data_dir.clone();
+				let target = self.data.clone();
 				let main_pb = main_pb.clone();
 				let mp_clone = mp.clone();
 				let rt_handle = rt.handle().clone(); // Clone the runtime handle for each thread
diff --git a/crates/llmfs/src/command/mod.rs b/crates/llmfs/src/command/mod.rs
index 7757d55..b2e0604 100644
--- a/crates/llmfs/src/command/mod.rs
+++ b/crates/llmfs/src/command/mod.rs
@@ -1,5 +1,5 @@
 mod download;
-mod sample_data;
+mod train_model;
 mod train_tokenizer;

 #[derive(Debug, clap::Subcommand)]
@@ -19,7 +19,7 @@ pub enum SubCommand {
 	/// Train model
 	TrainModel {
 		#[command(flatten)]
-		args: sample_data::TrainModelArgs,
+		args: train_model::TrainModelArgs,
 	},
 }
diff --git a/crates/llmfs/src/command/sample_data.rs b/crates/llmfs/src/command/sample_data.rs
deleted file mode 100644
index 5ae5b13..0000000
--- a/crates/llmfs/src/command/sample_data.rs
+++ /dev/null
@@ -1,735 +0,0 @@
-use ahash::AHasher;
-use anyhow::{Context, Result};
-use burn::{
-	Tensor,
-	backend::{Autodiff, Cuda, cuda::CudaDevice},
-	config::Config,
-	module::{AutodiffModule, Module, Param, ParamId},
-	nn::{
-		Dropout, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
-		loss::CrossEntropyLossConfig,
-		transformer::{PositionWiseFeedForward, PositionWiseFeedForwardConfig},
-	},
-	optim::{AdamConfig, GradientsParams, Optimizer},
-	prelude::{Backend, ToElement},
-	tensor::{Bool, Distribution, Int, activation::softmax},
-};
-use burn_train::ClassificationOutput;
-use clap::Args;
-use indicatif::MultiProgress;
-use ndarray::{Array1, Array2};
-use std::{
-	collections::VecDeque,
-	f32,
-	fs::File,
-	hash::Hasher,
-	path::{Path, PathBuf},
-};
-use tokenizer::Tokenizer;
-use tracing::{debug, info};
-
-use crate::data_reader::{DataReader, DataReaderError};
-
-// Text generation routine
-
-/*
-{
-	let init = "Initial context. This is ";
-	let tokens = tokenizer.encode(&init);
-
-	let n_tokens = tokens.len();
-	let input: Array1<u32> = Array1::from_vec(tokens);
-	let mut input: Tensor<B, 1, Int> =
-		Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
-			.reshape([n_tokens]);
-
-	for _ in 0..100 {
-		let tokens: Vec<u32> = input.clone().to_data().convert::<u32>().into_vec().unwrap();
-		println!("{:?}", tokens);
-		println!("{}", tokenizer.decode(&tokens));
-
-		// Crop idx to context size;
-		let batch = input
-			.clone()
-			.slice([0..config.context_size])
-			.unsqueeze_dim(0);
-
-		// shape: [tokens, vocab_size]
-		let logits = model.forward(batch).squeeze_dim::<2>(0);
-
-		// shape: [vocab_size]
-		let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0);
-
-		let probs = softmax(logits, 0); // shape: [n_tokens]
-		let id_next = probs.argmax(0); // shape: [1]
-
-		input = Tensor::cat(vec![input.slice([1..]), id_next], 0);
-	}
-}
-*/
-
-struct TrainTestIterator<'a, B: Backend> {
-	reader: DataReader,
-
-	ccfg: &'a ComputeConfig,
-	mcfg: &'a GptModelConfig,
-	tokenizer: &'a Tokenizer,
-	eval: bool,
-	device: &'a B::Device,
-
-	error: bool,
-
-	// Tokenized input/output pairs
-	pairs: VecDeque<(Vec<u32>, u32)>,
-}
-
-impl<'a, B: Backend> TrainTestIterator<'a, B> {
-	pub fn new(
-		data_dir: impl AsRef<Path>,
-		ccfg: &'a ComputeConfig,
-		mcfg: &'a GptModelConfig,
-		tokenizer: &'a Tokenizer,
-		eval: bool,
-		device: &'a B::Device,
-	) -> Result<Self> {
-		let reader = DataReader::new(10, data_dir)?;
-
-		Ok(Self {
-			reader,
-			ccfg,
-			mcfg,
-			tokenizer,
-			eval,
-			device,
-
-			error: false,
-			pairs: VecDeque::new(),
-		})
-	}
-}
-
-impl<B: Backend> Iterator for TrainTestIterator<'_, B> {
-	type Item = Result<TrainBatch<B>, DataReaderError>;
-
-	fn next(&mut self) -> Option<Self::Item> {
-		if self.error {
-			return None;
-		}
-
-		let mut inputs = Vec::with_capacity(self.ccfg.batch_size);
-		let mut targets = Vec::with_capacity(self.ccfg.batch_size);
-		let stride = self.mcfg.context_size;
-
-		while inputs.len() < self.ccfg.batch_size {
-			match self.pairs.pop_front() {
-				Some((i, t)) => {
-					// train/test split
-					{
-						let mut hasher = AHasher::default();
-						hasher.write(self.ccfg.eval_salt.as_bytes());
-
-						// Don't care about endianness, ahash output is unstable anyway
-						hasher.write(unsafe { std::mem::transmute(&i[..]) });
-						hasher.write_u32(t);
-
-						let test = // is this point in the test set?
-							hasher.finish() > (u64::MAX as f64 * self.ccfg.eval_frac).to_u64();
-
-						if test ^ self.eval {
-							continue;
-						}
-					}
-
-					inputs.push(i);
-					targets.push(t);
-				}
-
-				None => {
-					let text = match self.reader.next() {
-						None => break,
-						Some(Ok(x)) => x,
-						Some(Err(x)) => {
-							self.error = true;
-							return Some(Err(x));
-						}
-					};
-
-					let emb = self.tokenizer.encode(&text);
-
-					// Skip small texts
-					//
-					// TODO: do this better
-					// TODO: maybe using <|bos|>?
-					// TODO: non-uniform batches?
-					if emb.len() < self.mcfg.context_size {
-						continue;
-					}
-
-					let pairs = emb
-						.windows(self.mcfg.context_size + 1)
-						.step_by(stride)
-						.map(|x| {
-							(
-								x[..self.mcfg.context_size].to_vec(),
-								x[self.mcfg.context_size],
-							)
-						});
-
-					self.pairs.extend(pairs);
-				}
-			}
-		}
-
-		if inputs.is_empty() {
-			return None;
-		}
-
-		let shape = [inputs.len(), self.mcfg.context_size];
-
-		// Arrange data in memory
-		let inputs: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| inputs[a][b]);
-		let targets: Array1<u32> = Array1::from_vec(targets);
-
-		// Create tensors on gpu
-		#[expect(clippy::unwrap_used)]
-		let inputs =
-			Tensor::<B, 2, Int>::from_ints(inputs.as_slice().unwrap(), self.device).reshape(shape);
-
-		#[expect(clippy::unwrap_used)]
-		let targets = Tensor::<B, 1, Int>::from_ints(targets.as_slice().unwrap(), self.device);
-
-		return Some(Ok(TrainBatch { inputs, targets }));
-	}
-}
-
-#[derive(Debug, Args, Clone)]
-
-pub struct TrainModelArgs {
-	/// Path to training data
-	data: PathBuf,
-
-	/// Path to tokenizer
-	#[clap(long)]
-	tokenizer: PathBuf,
-}
-
-pub struct ComputeConfig {
-	pub batch_size: usize,
-	pub eval_frac: f64,
-	pub eval_salt: String,
-}
-
-impl TrainModelArgs {
-	pub fn run(self, _mp: Option<MultiProgress>) -> Result<()> {
-		let device = CudaDevice::new(0);
-		//let device = WgpuDevice::DiscreteGpu(0);
-
-		let tokenizer = File::open(&self.tokenizer).context("while opening tokenizer")?;
-		let tokenizer: Tokenizer =
-			serde_json::from_reader(tokenizer).context("while loading tokenizer")?;
-
-		let ccfg = ComputeConfig {
-			batch_size: 10,
-			eval_frac: 0.1,
-			eval_salt: "salt".into(),
-		};
-
-		let mcfg = GptModelConfig {
-			vocab_size: tokenizer.vocab_size(),
-			context_size: 256,
-			embed_dim: 768,
-			n_heads: 12,
-			head_dim: 64, // = 768 / 12
-			n_layers: 1,
-			embed_drop_rate: 0.1,
-			attention_drop_rate: 0.1,
-			shortcut_drop_rate: 0.1,
-		};
-
-		let mut model: GptModel<Autodiff<Cuda>> = mcfg.init(&device);
-
-		/*
-		let loader_train = DataLoaderBuilder::new(batcher.clone())
-			.batch_size(ccfg.batch_size)
-			//.shuffle(config.seed)
-			.num_workers(5)
-			.build(Loader::new(&self.data_dir).context("while initializing loader")?);
-
-		let loader_test = DataLoaderBuilder::new(batcher)
-			.batch_size(ccfg.batch_size)
-			//.shuffle(config.seed)
-			.num_workers(5)
-			.build(Loader::new(&self.data_dir).context("while initializing loader")?);
-
-		let learner = LearnerBuilder::new("./tmp")
-			.metric_train_numeric(AccuracyMetric::new())
-			.metric_valid_numeric(AccuracyMetric::new())
-			.metric_train_numeric(LossMetric::new())
-			.metric_valid_numeric(LossMetric::new())
-			.with_file_checkpointer(CompactRecorder::new())
-			.learning_strategy(LearningStrategy::SingleDevice(device.clone()))
-			.num_epochs(10)
-			.summary()
-			.build(model, AdamConfig::new().init(), 1e-4);
-
-		learner.fit(loader_train, loader_test);
-		*/
-
-		// Initialize optimizer
-		let mut optim = AdamConfig::new().init();
-		let learning_rate = 1e-4;
-
-		for epoch in 0..10 {
-			debug!("Running epoch {epoch}");
-
-			// Training phase
-			let mut train_loss_sum = 0.0;
-			let mut train_total = 0;
-
-			for batch in
-				TrainTestIterator::new(&self.data, &ccfg, &mcfg, &tokenizer, false, &device)
-					.context("while initializing reader")?
-			{
-				let batch = batch.context("while reading batch")?;
-
-				// Forward pass with gradients
-				let output = model.forward_train(batch.inputs, batch.targets);
-
-				train_total += output.targets.dims()[0] as i32;
-				train_loss_sum += output.loss.clone().into_scalar().to_f32();
-
-				let grads = output.loss.backward();
-				let grads = GradientsParams::from_grads(grads, &model);
-
-				model = optim.step(learning_rate, model, grads);
-			}
-
-			let mut valid_loss_sum = 0.0;
-			let mut valid_total = 0;
-
-			let mut n_eval = 0;
-			debug!("Evaluating batches");
-
-			for batch in TrainTestIterator::new(&self.data, &ccfg, &mcfg, &tokenizer, true, &device)
-				.context("while initializing reader")?
-			{
-				let batch = batch.context("while reading batch")?;
-				n_eval += batch.targets.shape()[0];
-
-				// Forward pass without gradients
-				let output = model.valid().forward_train(batch.inputs, batch.targets);
-
-				valid_total += output.targets.dims()[0] as i32;
-				valid_loss_sum += output.loss.into_scalar().to_f32();
-			}
-
-			// Compute and log epoch results
-			let train_loss = if train_total > 0 {
-				train_loss_sum / train_total as f32
-			} else {
-				0.0
-			};
-			let valid_loss = if valid_total > 0 {
-				valid_loss_sum / valid_total as f32
-			} else {
-				0.0
-			};
-
-			info!(message = "Ran epoch", epoch, train_loss, valid_loss, n_eval);
-		}
-
-		Ok(())
-	}
-}
-
-//
-// MARK: model
-//
-
-/// Multihead attention.
-///
-/// Equivalent to many stacked CausalAttention layers.
-/// These are packed inside one big tensor for efficiency.
-#[derive(Module, Debug)]
-pub struct MultiheadAttention<B: Backend> {
-	n_heads: usize,
-	head_dim: usize,
-
-	// Can also use Linear layers with disabled bias
-	// (they may also have a better initialization routine)
-	// TODO: see source code, make this equivalent
-	/// Query weight matrices for each head, stacked on the last dimension.
-	/// (so that shape is [tokens, n_heads * head_dim])
-	///
-	/// Intuitively, this learns "what question to ask about the text"
-	/// for a given query token. (e.g, "it" -> what does "it" refer to?)
-	w_query: Param<Tensor<B, 2>>,
-
-	/// Key weight matrices for each head, stacked on the last dimension.
-	/// (so that shape is [tokens, n_heads * head_dim])
-	///
-	/// Intuitively, this learns what properties a certain token
-	/// has when it appears as a context (non-query) token.
-	w_key: Param<Tensor<B, 2>>,
-
-	/// Value weight matrices for each head, stacked on the last dimension.
-	/// (so that shape is [tokens, n_heads * head_dim])
-	///
-	/// Intuitively, ???
-	w_value: Param<Tensor<B, 2>>,
-
-	/// Optional final projection.
-	/// Maps [total_dim, total_dim] to [total_dim, total_dim]
-	w_output: Param<Tensor<B, 2>>,
-
-	dropout: Dropout,
-
-	/// Upper-triangular matrix of ones, excluding diagonal.
-	/// Used to mask future tokens.
-	utri_mask: Tensor<B, 2, Bool>,
-}
-
-impl<B: Backend> MultiheadAttention<B> {
-	pub fn new(
-		embedding_dim: usize,
-		head_dim: usize,
-		n_heads: usize,
-		context_length: usize,
-		dropout: f64,
-		device: &B::Device,
-	) -> Self {
-		let total_dim = head_dim * n_heads;
-
-		Self {
-			n_heads,
-			head_dim,
-
-			w_query: Param::uninitialized(
-				ParamId::new(),
-				move |device, is_require_grad| {
-					Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
-						.set_require_grad(is_require_grad)
-				},
-				device.clone(),
-				true,
-				[embedding_dim, total_dim].into(),
-			),
-
-			w_key: Param::uninitialized(
-				ParamId::new(),
-				move |device, is_require_grad| {
-					Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
-						.set_require_grad(is_require_grad)
-				},
-				device.clone(),
-				true,
-				[embedding_dim, total_dim].into(),
-			),
-
-			w_value: Param::uninitialized(
-				ParamId::new(),
-				move |device, is_require_grad| {
-					Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
-						.set_require_grad(is_require_grad)
-				},
-				device.clone(),
-				true,
-				[embedding_dim, total_dim].into(),
-			),
-
-			w_output: Param::uninitialized(
-				ParamId::new(),
-				move |device, is_require_grad| {
-					Tensor::random([total_dim, total_dim], Distribution::Default, device)
-						.set_require_grad(is_require_grad)
-				},
-				device.clone(),
-				true,
-				[total_dim, total_dim].into(),
-			),
-
-			dropout: Dropout { prob: dropout },
-
-			utri_mask: Tensor::<B, 2, Bool>::tril_mask([context_length, context_length], 0, device),
-		}
-	}
-
-	/// Compute self-attention vector for the given batch
-	///
-	/// - input shape is [batch, token, token_dim]
-	/// - input shape is [batch, token, n_heads * head_dim]
-	pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
-		// Works similarly to self-attention, (where attn = softmax(tok @ tok^T); context = attn @ tok)
-		// But adds an "inner latent space" using Wq, Qk, and Wv.
-		//
-		// Multiple heads are batched into one tensor.
-
-		let batch = input.dims()[0];
-		let tokens = input.dims()[1];
-
-		let w_query = self
-			.w_query
-			.val()
-			.unsqueeze_dim::<3>(0)
-			.expand([batch as i64, -1, -1]);
-
-		let w_key = self
-			.w_key
-			.val()
-			.unsqueeze_dim::<3>(0)
-			.expand([batch as i64, -1, -1]);
-
-		let w_value = self
-			.w_value
-			.val()
-			.unsqueeze_dim::<3>(0)
-			.expand([batch as i64, -1, -1]);
-
-		let w_output = self
-			.w_output
-			.val()
-			.unsqueeze_dim::<3>(0)
-			.expand([batch as i64, -1, -1]);
-
-		// Map batch to inner latent space.
-		// shape: [batch, token, inner_dim]
-		let queries = input.clone().matmul(w_query);
-		let keys = input.clone().matmul(w_key);
-		let values = input.clone().matmul(w_value);
-
-		// Split head dimensions
-		let keys = keys.reshape([batch, tokens, self.n_heads, self.head_dim]);
-		let values = values.reshape([batch, tokens, self.n_heads, self.head_dim]);
-		let queries = queries.reshape([batch, tokens, self.n_heads, self.head_dim]);
-
-		// from: [batch, tok, head, head_dim]
-		// to:   [batch, head, tok, head_dim]
-		let keys = keys.swap_dims(1, 2);
-		let values = values.swap_dims(1, 2);
-		let queries = queries.swap_dims(1, 2);
-
-		// Compute attention scores for each head
-		// (cosine similarity of each query token to each context token, per head)
-		//
-		// lhs shape:    [batch, head, tok, head_dim]
-		// rhs shape:    [batch, head, head_dim, tok]
-		// output shape: [batch, head, query_token, context_token]
-		let attn_scores = queries.matmul(keys.clone().swap_dims(2, 3));
-
-		let mask = self
-			.utri_mask
-			.clone()
-			.slice([0..tokens, 0..tokens])
-			.unsqueeze_dim::<3>(0)
-			.unsqueeze_dim::<4>(0)
-			.expand(attn_scores.shape());
-
-		// Mask out future tokens by filling
-		// upper-triangular with -inf, which becomes 0.0 after softmax.
-		let attn_scores = attn_scores.mask_fill(mask, f32::NEG_INFINITY);
-
-		// Normalize attn weights.
-		//
-		// Divide by sqrt(inner_dim) because...
-		// - dot products get larger with larger dimensions
-		// - this causes softmax to "saturate", making all other values very small
-		// - which makes gradients vanish during training
-		let attn_weights = softmax(attn_scores / (keys.shape()[3] as f32).sqrt(), 3);
-		let attn_weights = self.dropout.forward(attn_weights);
-
-		// lhs shape:    [batch, head, query_token, context_token]
-		// rhs shape:    [batch, head, tok, head_dim]
-		// matmul shape: [batch, head, tok, head_dim]
-		// out shape:    [batch, tok, head, head_dim]
-		let context_vec = attn_weights.matmul(values).swap_dims(1, 2);
-
-		// shape: [batch, tok, stacked_dim]
-		let context_vec = context_vec.reshape([batch, tokens, self.n_heads * self.head_dim]);
-
-		// Apply final projection (optional)
-		let context_vec = context_vec.matmul(w_output);
-
-		return context_vec;
-	}
-}
-
-#[derive(Config, Debug)]
-pub struct GptModelConfig {
-	/// Number of tokens
-	pub vocab_size: u32,
-
-	/// Maximum number of input tokens with positional embeddings
-	pub context_size: usize,
-
-	/// Dimension of each token's embedding
-	pub embed_dim: usize,
-
-	/// Number of attention heads
-	pub n_heads: usize,
-
-	/// Dimension of each attn head
-	pub head_dim: usize,
-
-	/// Number of transformer blocks
-	pub n_layers: usize,
-
-	pub embed_drop_rate: f64,
-	pub attention_drop_rate: f64,
-	pub shortcut_drop_rate: f64,
-}
-
-impl GptModelConfig {
-	pub fn init<B: Backend>(&self, device: &B::Device) -> GptModel<B> {
-		let out_head_shape = [self.embed_dim, self.vocab_size as usize];
-
-		GptModel {
-			embedder_tok: EmbeddingConfig::new(self.vocab_size as usize, self.embed_dim)
-				.init(device),
-
-			embedder_pos: EmbeddingConfig::new(self.context_size, self.embed_dim).init(device),
-
-			embedder_drop: Dropout {
-				prob: self.embed_drop_rate,
-			},
-
-			trf_blocks: (0..self.n_layers)
-				.map(|_| TransformerBlock::new(&self, device))
-				.collect(),
-
-			final_norm: LayerNormConfig::new(self.embed_dim).init(device),
-
-			out_head: Param::uninitialized(
-				ParamId::new(),
-				move |device, is_require_grad| {
-					Tensor::random(out_head_shape, Distribution::Default, device)
-						.set_require_grad(is_require_grad)
-				},
-				device.clone(),
-				true,
out_head_shape.into(), - ), - } - } -} - -#[derive(Debug, Clone)] -pub struct TrainBatch { - pub inputs: Tensor, - - /// Correct next token for each input - pub targets: Tensor, -} - -#[derive(Module, Debug)] -pub struct GptModel { - embedder_tok: Embedding, - embedder_pos: Embedding, - embedder_drop: Dropout, - - trf_blocks: Vec>, - final_norm: LayerNorm, - out_head: Param>, -} - -impl GptModel { - pub fn forward(&self, input: Tensor) -> Tensor { - let n_tokens = input.shape()[1]; - - let embed_tok = self.embedder_tok.forward(input.clone()); - let embed_pos = self - .embedder_pos - .forward(Tensor::arange(0..n_tokens as i64, &input.device()).unsqueeze_dim(0)); - - let x = embed_tok + embed_pos; - let x = self.embedder_drop.forward(x); - let x = self.trf_blocks.iter().fold(x, |x, l| l.forward(x)); - let x = self.final_norm.forward(x); - let logits = x.matmul(self.out_head.val().unsqueeze_dim(0)); - - return logits; - } - - pub fn forward_train( - &self, - inputs: Tensor, - targets: Tensor, - ) -> ClassificationOutput { - // shape: [batch, n_tokens, n_vocabulary] - let output = self.forward(inputs); - - // Get last token - // shape: [batch, n_vocabulary] - let output = output.slice_dim(1, -1).squeeze_dim::<2>(1); - - let loss = CrossEntropyLossConfig::new() - .init(&targets.device()) - .forward(output.clone(), targets.clone()); - - ClassificationOutput { - loss, - output, - targets, - } - } -} - -#[derive(Module, Debug)] -pub struct TransformerBlock { - attention: MultiheadAttention, - - /// TODO: wtf? - ff: PositionWiseFeedForward, - - /// TODO: wtf? - norm_a: LayerNorm, - norm_b: LayerNorm, - - drop_shortcut: Dropout, -} - -impl TransformerBlock { - pub fn new(cfg: &GptModelConfig, device: &B::Device) -> Self { - Self { - attention: MultiheadAttention::new( - cfg.embed_dim, - cfg.head_dim, - cfg.n_heads, - cfg.context_size, - cfg.attention_drop_rate, - device, - ), - - ff: PositionWiseFeedForwardConfig::new(cfg.embed_dim, 4 * cfg.embed_dim) - .with_dropout(0.0) - .init(device), - - norm_a: LayerNormConfig::new(cfg.embed_dim).init(device), - norm_b: LayerNormConfig::new(cfg.embed_dim).init(device), - - drop_shortcut: Dropout { - prob: cfg.shortcut_drop_rate, - }, - } - } - - pub fn forward(&self, input: Tensor) -> Tensor { - let input = { - let shortcut = input.clone(); - let x = self.norm_a.forward(input); - let x = self.attention.forward(x); - let x = self.drop_shortcut.forward(x); - x + shortcut - }; - - let input = { - // TODO: wtf? 
- let shortcut = input.clone(); - let x = self.norm_b.forward(input); - let x = self.ff.forward(x); - let x = self.drop_shortcut.forward(x); - x + shortcut - }; - - return input; - } -} diff --git a/crates/llmfs/src/command/train_model.rs b/crates/llmfs/src/command/train_model.rs new file mode 100644 index 0000000..a3c3058 --- /dev/null +++ b/crates/llmfs/src/command/train_model.rs @@ -0,0 +1,312 @@ +use anyhow::{Context, Result}; +use burn::{ + backend::Autodiff, + module::{AutodiffModule, Module}, + optim::{AdamConfig, GradientsParams, Optimizer}, + prelude::ToElement, + record::{FullPrecisionSettings, NamedMpkFileRecorder}, + tensor::backend::AutodiffBackend, +}; +use clap::Args; +use indicatif::{MultiProgress, ProgressBar}; +use std::{ + f32, + fs::File, + num::NonZero, + path::PathBuf, + time::{Duration, Instant}, +}; +use tokenizer::Tokenizer; +use tracing::{debug, info}; + +use crate::{ + InferenceDevice, + cli::{progress_big, progress_persec}, + parts::{GptModel, GptModelConfig}, + train_test_iterator::TrainTestIterator, +}; + +// Text generation routine +/* +{ + let init = "Initial context. This is "; + let tokens = tokenizer.encode(&init); + + let n_tokens = tokens.len(); + let input: Array1 = Array1::from_vec(tokens); + let mut input: Tensor = + Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device) + .reshape([n_tokens]); + + for _ in 0..100 { + let tokens: Vec = input.clone().to_data().convert::().into_vec().unwrap(); + println!("{:?}", tokens); + println!("{}", tokenizer.decode(&tokens)); + + // Crop idx to context size; + let batch = input + .clone() + .slice([0..config.context_size]) + .unsqueeze_dim(0); + + // shape: [tokens, vocab_size] + let logits = model.forward(batch).squeeze_dim::<2>(0); + + // shape: [vocab_size] + let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0); + + let probs = softmax(logits, 0); // shape: [n_tokens] + let id_next = probs.argmax(0); // shape: [1] + + input = Tensor::cat(vec![input.slice([1..]), id_next], 0); + } +} +*/ + +#[derive(Debug, Args, Clone)] +pub struct TrainModelArgs { + /// Path to training data + data: PathBuf, + + /// Path to tokenizer + #[clap(long, default_value = "tokenizer.json")] + tokenizer: PathBuf, + + /// directory to save checkpoints + #[clap(long, default_value = "checkpoints")] + checkpoints: PathBuf, + + /// The device to use for compute. 
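
The commented-out routine above sketches greedy decoding with a fixed-size window. A cleaned-up, self-contained version might look like the sketch below. The `GptModel`, `Tokenizer`, and tensor calls are the ones used elsewhere in this patch; the helper itself, and the assumption that `encode` returns `Vec<u32>` and `decode` accepts `&[u32]`, are mine:

```rust
use burn::{Tensor, prelude::{Backend, ToElement}, tensor::Int};

/// Hypothetical helper (not part of this patch): greedy decoding,
/// always appending the most likely next token.
fn generate<B: Backend>(
    model: &GptModel<B>,
    tokenizer: &Tokenizer,
    prompt: &str,
    context_size: usize,
    n_new: usize,
    device: &B::Device,
) -> String {
    let mut tokens = tokenizer.encode(prompt);

    for _ in 0..n_new {
        // Feed at most the last `context_size` tokens to the model.
        let start = tokens.len().saturating_sub(context_size);
        let input: Tensor<B, 2, Int> =
            Tensor::<B, 1, Int>::from_ints(&tokens[start..], device).unsqueeze_dim(0);

        // logits shape: [1, window, vocab]; keep only the last position.
        let logits = model.forward(input);
        let last = logits.slice_dim(1, -1).squeeze_dim::<2>(1).squeeze_dim::<1>(0);

        // Greedy: argmax over the vocabulary.
        tokens.push(last.argmax(0).into_scalar().to_u32());
    }

    tokenizer.decode(&tokens)
}
```
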
`wgpu:n`, `cuda:n`, or `cpu` + #[clap(long, default_value = "cpu")] + device: InferenceDevice, + + /// Training batch size + #[clap(long, default_value = "10")] + batch: NonZero, + + /// Proportion of data reserved for evaluation + #[clap(long, default_value = "0.1")] + eval_frac: f64, + + /// Eval hasher salt + #[clap(long, default_value = "eval-salt")] + eval_salt: String, + + /// Number of threads reading data + #[clap(long, default_value = "5")] + readers: usize, +} + +pub struct ComputeConfig { + pub batch_size: usize, + pub eval_frac: f64, + pub eval_salt: String, +} + +impl TrainModelArgs { + pub fn run(self, mp: Option) -> Result<()> { + match self.device { + InferenceDevice::Cpu => { + use burn::backend::NdArray; + use burn::backend::ndarray::NdArrayDevice; + + let device = NdArrayDevice::Cpu; + self.run_inner::>(mp, device)?; + } + + InferenceDevice::Cuda(x) => { + use burn::backend::Cuda; + use burn::backend::cuda::CudaDevice; + + let device = CudaDevice::new(x); + self.run_inner::>(mp, device)?; + } + + InferenceDevice::Wgpu(x) => { + use burn::backend::Wgpu; + use burn::backend::wgpu::WgpuDevice; + + let device = WgpuDevice::DiscreteGpu(x); + self.run_inner::>(mp, device)?; + } + }; + + return Ok(()); + } + + fn run_inner( + self, + mp: Option, + device: B::Device, + ) -> Result<()> { + let tokenizer = File::open(&self.tokenizer).context("while opening tokenizer")?; + let tokenizer: Tokenizer = + serde_json::from_reader(tokenizer).context("while loading tokenizer")?; + + let ccfg = ComputeConfig { + batch_size: self.batch.get(), + eval_frac: self.eval_frac, + eval_salt: self.eval_salt.clone(), + }; + + let mcfg = GptModelConfig { + vocab_size: tokenizer.vocab_size(), + context_size: 256, // TODO: MORE! + embed_dim: 768, + n_heads: 12, + head_dim: 64, // = 768 / 12 + n_layers: 12, + embed_drop_rate: 0.1, + attention_drop_rate: 0.1, + shortcut_drop_rate: 0.1, + }; + + let mut model: GptModel = mcfg.init(&device); + + let mut optim = AdamConfig::new().init(); + let learning_rate = 1e-4; + + std::fs::create_dir_all(&self.checkpoints).context("while creating checkpoint dir")?; + let recorder = NamedMpkFileRecorder::::new(); + + let main_pb = mp.as_ref().map(|mp| { + let pb = mp.add(ProgressBar::new(10 as u64)); + pb.set_style(progress_big()); + pb.set_message("Training model"); + pb.enable_steady_tick(Duration::from_millis(100)); + pb + }); + + for epoch in 0..10 { + let start = Instant::now(); + debug!("Running epoch {epoch}"); + + let epoch_pb = mp.as_ref().map(|mp| { + let pb = mp.add(ProgressBar::no_length()); + pb.set_style(progress_persec()); + pb.set_message(format!("Running epoch {epoch}")); + pb.enable_steady_tick(Duration::from_millis(100)); + pb + }); + + // Training phase + let mut train_loss_sum = 0.0; + let mut train_total = 0; + + let mut n_train = 0u64; + for batch in TrainTestIterator::new( + &self.data, + &tokenizer, + false, + ccfg.batch_size, + mcfg.context_size, + ccfg.eval_frac, + &ccfg.eval_salt, + self.readers, + &device, + ) + .context("while initializing reader")? 
+ { + let batch = batch.context("while reading batch")?; + epoch_pb.as_ref().map(|x| x.set_position(n_train)); + n_train += batch.inputs.shape()[0] as u64; + + // Forward pass with gradients + let output = model.forward_train(batch.inputs, batch.targets); + + train_total += output.targets.dims()[0] as i32; + train_loss_sum += output.loss.clone().into_scalar().to_f32(); + + let grads = output.loss.backward(); + let grads = GradientsParams::from_grads(grads, &model); + + model = optim.step(learning_rate, model, grads); + } + + epoch_pb.map(|x| x.finish_and_clear()); + + let mut valid_loss_sum = 0.0; + let mut valid_total = 0; + + let mut n_eval = 0; + debug!("Evaluating batches"); + + let eval_pb = mp.as_ref().map(|mp| { + let pb = mp.add(ProgressBar::no_length()); + pb.set_style(progress_persec()); + pb.set_message(format!("Evaluating epoch {epoch}")); + pb.enable_steady_tick(Duration::from_millis(100)); + pb + }); + + for batch in TrainTestIterator::new( + &self.data, + &tokenizer, + true, + ccfg.batch_size, + mcfg.context_size, + ccfg.eval_frac, + &ccfg.eval_salt, + self.readers, + &device, + ) + .context("while initializing reader")? + { + let batch = batch.context("while reading batch")?; + eval_pb.as_ref().map(|x| x.set_position(n_eval)); + n_eval += batch.inputs.shape()[0] as u64; + + // Forward pass without gradients + let output = model.valid().forward_train(batch.inputs, batch.targets); + + valid_total += output.targets.dims()[0] as i32; + valid_loss_sum += output.loss.into_scalar().to_f32(); + } + + eval_pb.map(|x| x.finish_and_clear()); + + // Compute and log epoch results + let train_loss = if train_total > 0 { + train_loss_sum / train_total as f32 + } else { + 0.0 + }; + let valid_loss = if valid_total > 0 { + valid_loss_sum / valid_total as f32 + } else { + 0.0 + }; + + info!( + message = "Ran epoch", + epoch, + train_loss, + valid_loss, + n_train, + n_eval, + time_ms = start.elapsed().as_millis() + ); + main_pb.as_ref().map(|x| x.inc(1)); + + { + let target = self.checkpoints.join(format!("epoch-{epoch:02}")); + + info!(message = "Saving checkpoint", ?target); + std::fs::create_dir_all(&self.checkpoints) + .context("while creating checkpoint dir")?; + + model + .clone() + .save_file(target, &recorder) + .context("while saving checkpoint")?; + } + } + + if let Some(pb) = main_pb.as_ref() { + pb.finish_and_clear(); + info!("Training complete"); + } + + Ok(()) + } +} diff --git a/crates/llmfs/src/command/train_tokenizer.rs b/crates/llmfs/src/command/train_tokenizer.rs index b91c125..6f902b1 100644 --- a/crates/llmfs/src/command/train_tokenizer.rs +++ b/crates/llmfs/src/command/train_tokenizer.rs @@ -12,22 +12,25 @@ use crate::data_reader::DataReader; #[derive(Debug, Args, Clone)] pub struct TrainTokenizerArgs { - /// Where to save tokenizer - #[clap(default_value = "tokenizer.json")] - target: PathBuf, - /// Path to training data - #[clap(long, default_value = "data")] - data_dir: PathBuf, + data: PathBuf, + + /// Where to save tokenizer + #[clap(long, default_value = "tokenizer.json")] + target: PathBuf, /// Only train on the first n texts #[clap(long)] first_n: Option, - /// Number of threads to use for training + /// Number of threads to use for training. 0 to autodetect. 
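
One detail of the loss bookkeeping in `train_model.rs` above: `output.loss` is a mean over the batch (assuming burn's default cross-entropy reduction), while `train_total` and `valid_total` count individual examples, so the reported quotient understates the loss by roughly a factor of the batch size. A batch-weighted accumulation keeps the units consistent; a sketch of the alternative, not what the patch does:

```rust
// Weight each batch's mean loss by its size, so that dividing by the
// example count at the end of the epoch yields a per-example average.
let batch_len = output.targets.dims()[0];
train_loss_sum += output.loss.clone().into_scalar().to_f32() * batch_len as f32;
train_total += batch_len as i32;

// ...after the loop, unchanged:
// train_loss = train_loss_sum / train_total as f32
```
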
#[clap(long, default_value = "0")] threads: usize, + /// Number of threads reading data + #[clap(long, default_value = "5")] + readers: usize, + /// Tokenizer vocabulary size #[clap(long, default_value = "65535")] n_tokens: u32, @@ -35,7 +38,8 @@ pub struct TrainTokenizerArgs { impl TrainTokenizerArgs { pub fn run(self, mp: Option) -> Result<()> { - let iter = DataReader::new(5, &self.data_dir).context("while initializing data reader")?; + let iter = DataReader::new(self.readers.max(1), &self.data) + .context("while initializing data reader")?; #[expect(clippy::unwrap_used)] // Lazy error handling let iter = iter.map(|x| x.unwrap()); diff --git a/crates/llmfs/src/data_reader.rs b/crates/llmfs/src/data_reader.rs index 7e1188d..5a577d4 100644 --- a/crates/llmfs/src/data_reader.rs +++ b/crates/llmfs/src/data_reader.rs @@ -3,6 +3,7 @@ use parking_lot::Mutex; use parquet::errors::ParquetError; use parquet::file::reader::{FileReader, SerializedFileReader}; use parquet::record::RowAccessor; +use rand::seq::SliceRandom; use std::fs::File; use std::path::Path; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -25,10 +26,11 @@ pub enum DataReaderError { /// /// All parquet files have exactly one text column. /// No promises about this struct's behavior if this is not the case. +#[derive(Clone)] pub struct DataReader { - rx: Receiver>, + rx: Arc>>>, total_rows: usize, - consumed_rows: AtomicUsize, + consumed_rows: Arc, } impl DataReader { @@ -57,6 +59,8 @@ impl DataReader { files.push(path); } } + + files.shuffle(&mut rand::rng()); (Arc::new(Mutex::new(files)), total_rows) }; @@ -147,9 +151,9 @@ impl DataReader { }); Ok(Self { - rx, + rx: Arc::new(Mutex::new(rx)), total_rows, - consumed_rows: AtomicUsize::new(0), + consumed_rows: Arc::new(AtomicUsize::new(0)), }) } @@ -157,7 +161,7 @@ impl DataReader { /// Order is arbitrary. /// Returns `None` when all rows have been read. 
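
The `Clone` derive on `DataReader` above works because every field is now shared: the channel receiver sits behind `Arc<Mutex<..>>` and the row counter behind `Arc<AtomicUsize>`. The pattern in isolation (using std types here for a self-contained sketch; the crate itself uses `parking_lot::Mutex`):

```rust
use std::sync::{Arc, Mutex, mpsc::Receiver};

// Minimal sketch of a cloneable reader handle: all clones pull from one
// shared queue, and each recv holds the lock while pulling one item.
#[derive(Clone)]
struct SharedReader<T> {
    rx: Arc<Mutex<Receiver<T>>>,
}

impl<T> SharedReader<T> {
    fn recv(&self) -> Option<T> {
        self.rx.lock().ok()?.recv().ok()
    }
}
```
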
pub fn recv(&self) -> Option> { - self.rx.recv().ok() + self.rx.lock().recv().ok() } //pub fn try_recv(&self) -> Result, TryRecvError> { diff --git a/crates/llmfs/src/logging.rs b/crates/llmfs/src/logging.rs index 7660b9e..5b3ca27 100644 --- a/crates/llmfs/src/logging.rs +++ b/crates/llmfs/src/logging.rs @@ -12,18 +12,16 @@ use tracing_subscriber::{ // MARK: loglevel // -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, ValueEnum)] -#[derive(Default)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, ValueEnum, Default)] pub enum LogLevel { Trace, Debug, #[default] - Info, + Info, Warn, Error, } - impl Display for LogLevel { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -47,7 +45,7 @@ pub struct LoggingConfig { pub silence: LogLevel, // Bins - pub nanochat: LogLevel, + pub llmfs: LogLevel, } impl From for EnvFilter { @@ -71,7 +69,7 @@ impl From for EnvFilter { // // Bins // - format!("nanochat_rs={}", conf.nanochat), + format!("llmfs={}", conf.llmfs), conf.other.to_string(), ] .join(","), @@ -164,31 +162,31 @@ impl LogFilterPreset { Self::Error => LoggingConfig { other: LogLevel::Error, silence: LogLevel::Error, - nanochat: LogLevel::Error, + llmfs: LogLevel::Error, }, Self::Warn => LoggingConfig { other: LogLevel::Warn, silence: LogLevel::Warn, - nanochat: LogLevel::Warn, + llmfs: LogLevel::Warn, }, Self::Info => LoggingConfig { other: LogLevel::Warn, silence: LogLevel::Warn, - nanochat: LogLevel::Info, + llmfs: LogLevel::Info, }, Self::Debug => LoggingConfig { other: LogLevel::Warn, silence: LogLevel::Warn, - nanochat: LogLevel::Debug, + llmfs: LogLevel::Debug, }, Self::Trace => LoggingConfig { other: LogLevel::Trace, silence: LogLevel::Warn, - nanochat: LogLevel::Trace, + llmfs: LogLevel::Trace, }, } } @@ -216,16 +214,14 @@ pub enum LoggingTarget { } /// How to print logs -#[derive(Debug, Clone, Copy, Deserialize)] -#[derive(Default)] +#[derive(Debug, Clone, Copy, Deserialize, Default)] pub enum LoggingFormat { #[default] - Ansi, + Ansi, AnsiNoColor, Json, } - pub struct LoggingInitializer { /// Log filter for printed logs pub preset: LogFilterPreset, diff --git a/crates/llmfs/src/main.rs b/crates/llmfs/src/main.rs index 2297c77..4b0da1a 100644 --- a/crates/llmfs/src/main.rs +++ b/crates/llmfs/src/main.rs @@ -1,5 +1,8 @@ +#![recursion_limit = "256"] // needed to resolve burn types + use clap::Parser; use indicatif::MultiProgress; +use serde::{Deserialize, Deserializer}; use tracing::error; use crate::{ @@ -11,6 +14,8 @@ mod cli; mod command; mod data_reader; mod logging; +mod parts; +mod train_test_iterator; #[derive(Parser, Debug)] #[command(version, about, long_about = None, styles=crate::cli::clap_styles())] @@ -60,3 +65,66 @@ fn main() { std::process::exit(1); } } + +// +// +// + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub enum InferenceDevice { + #[default] + Cpu, + Cuda(usize), + Wgpu(usize), +} + +#[derive(Debug, PartialEq, Eq, thiserror::Error)] +#[error("{0}")] +pub struct ParseDeviceError(String); + +impl std::str::FromStr for InferenceDevice { + type Err = ParseDeviceError; + + fn from_str(s: &str) -> Result { + let s = s.to_lowercase(); + + if s == "cpu" { + return Ok(InferenceDevice::Cpu); + } + + if let Some(index_str) = s.strip_prefix("cuda:") { + return match index_str.parse::() { + Ok(index) => Ok(InferenceDevice::Cuda(index)), + Err(_) => Err(ParseDeviceError(format!( + "Invalid device index: '{}'", + index_str + ))), + }; + } + + if let Some(index_str) = s.strip_prefix("wgpu:") { + return 
match index_str.parse::() { + Ok(index) => Ok(InferenceDevice::Wgpu(index)), + Err(_) => Err(ParseDeviceError(format!( + "Invalid device index: '{}'", + index_str + ))), + }; + } + + return Err(ParseDeviceError(format!( + "Invalid device format: '{}'. Expected 'cpu', 'cuda:N', 'wgpu:N'", + s + ))); + } +} + +impl<'de> Deserialize<'de> for InferenceDevice { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + s.parse().map_err(serde::de::Error::custom) + } +} diff --git a/crates/llmfs/src/parts/attention.rs b/crates/llmfs/src/parts/attention.rs new file mode 100644 index 0000000..6845369 --- /dev/null +++ b/crates/llmfs/src/parts/attention.rs @@ -0,0 +1,228 @@ +use burn::{ + Tensor, + config::Config, + module::{Module, Param, ParamId}, + nn::Dropout, + prelude::Backend, + tensor::{Bool, Distribution, activation::softmax}, +}; +use std::f32; + +#[derive(Debug, Config)] +pub struct MultiheadAttentionConfig { + pub context_size: usize, + pub embed_dim: usize, + + pub n_heads: usize, + pub head_dim: usize, + pub drop_rate: f64, +} + +impl MultiheadAttentionConfig { + pub fn init(&self, device: &B::Device) -> MultiheadAttention { + let total_dim = self.head_dim * self.n_heads; + let embedding_dim = self.embed_dim; + + MultiheadAttention { + n_heads: self.n_heads, + head_dim: self.head_dim, + + w_query: Param::uninitialized( + ParamId::new(), + move |device, is_require_grad| { + Tensor::random([embedding_dim, total_dim], Distribution::Default, device) + .set_require_grad(is_require_grad) + }, + device.clone(), + true, + [self.embed_dim, total_dim].into(), + ), + + w_key: Param::uninitialized( + ParamId::new(), + move |device, is_require_grad| { + Tensor::random([embedding_dim, total_dim], Distribution::Default, device) + .set_require_grad(is_require_grad) + }, + device.clone(), + true, + [self.embed_dim, total_dim].into(), + ), + + w_value: Param::uninitialized( + ParamId::new(), + move |device, is_require_grad| { + Tensor::random([embedding_dim, total_dim], Distribution::Default, device) + .set_require_grad(is_require_grad) + }, + device.clone(), + true, + [self.embed_dim, total_dim].into(), + ), + + w_output: Param::uninitialized( + ParamId::new(), + move |device, is_require_grad| { + Tensor::random([total_dim, total_dim], Distribution::Default, device) + .set_require_grad(is_require_grad) + }, + device.clone(), + true, + [total_dim, total_dim].into(), + ), + + dropout: Dropout { + prob: self.drop_rate, + }, + + utri_mask: Tensor::::tril_mask( + [self.context_size, self.context_size], + 0, + device, + ), + } + } +} + +/// Multihead attention. +/// +/// Equivalent to many stacked CausalAttention layers. +/// These are packed inside one big tensor for efficiency. +#[derive(Module, Debug)] +pub struct MultiheadAttention { + n_heads: usize, + head_dim: usize, + + // Can also use Linear layers with disabled bias + // (they may also have a better initialization routine) + // TODO: see source code, make this equivalent + /// Query weight matrices for each head, stacked on the last dimension. + /// (so that shape is [tokens, n_heads * head_dim]) + /// + /// Intuitively, this learns "what question to ask about the text" + /// for a given query token. (e.g, "it" -> what does "it" refer to?) + w_query: Param>, + + /// Key weight matrices for each head, stacked on the last dimension. 
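
The `InferenceDevice` parser earlier in `main.rs` lowercases its input before matching prefixes. A few spot checks, written as a hypothetical test relying on the derived `PartialEq` and `Debug` impls:

```rust
#[test]
fn parse_inference_device() {
    assert_eq!("cpu".parse(), Ok(InferenceDevice::Cpu));
    assert_eq!("CUDA:1".parse(), Ok(InferenceDevice::Cuda(1)));
    assert_eq!("wgpu:0".parse(), Ok(InferenceDevice::Wgpu(0)));
    assert!("cuda:x".parse::<InferenceDevice>().is_err());
    assert!("tpu:0".parse::<InferenceDevice>().is_err());
}
```
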
+    /// (so that shape is [tokens, n_heads * head_dim])
+    ///
+    /// Intuitively, this learns what properties a certain token
+    /// has when it appears as a context (non-query) token.
+    w_key: Param<Tensor<B, 2>>,
+
+    /// Value weight matrices for each head, stacked on the last dimension.
+    /// (so that shape is [tokens, n_heads * head_dim])
+    ///
+    /// Intuitively, this learns what information a token contributes
+    /// to the context vector once it has been attended to.
+    w_value: Param<Tensor<B, 2>>,
+
+    /// Optional final output projection, shape [total_dim, total_dim].
+    /// Mixes the concatenated head outputs.
+    w_output: Param<Tensor<B, 2>>,
+
+    dropout: Dropout,
+
+    /// Upper-triangular matrix of ones, excluding diagonal.
+    /// Used to mask future tokens.
+    utri_mask: Tensor<B, 2, Bool>,
+}
+
+impl<B: Backend> MultiheadAttention<B> {
+    /// Compute the self-attention vector for the given batch.
+    ///
+    /// - input shape is [batch, token, token_dim]
+    /// - output shape is [batch, token, n_heads * head_dim]
+    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
+        // Works like plain self-attention (where attn = softmax(tok @ tok^T); context = attn @ tok),
+        // but adds an "inner latent space" using Wq, Wk, and Wv.
+        //
+        // Multiple heads are batched into one tensor.
+
+        let batch = input.dims()[0];
+        let tokens = input.dims()[1];
+
+        let w_query = self
+            .w_query
+            .val()
+            .unsqueeze_dim::<3>(0)
+            .expand([batch as i64, -1, -1]);
+
+        let w_key = self
+            .w_key
+            .val()
+            .unsqueeze_dim::<3>(0)
+            .expand([batch as i64, -1, -1]);
+
+        let w_value = self
+            .w_value
+            .val()
+            .unsqueeze_dim::<3>(0)
+            .expand([batch as i64, -1, -1]);
+
+        let w_output = self
+            .w_output
+            .val()
+            .unsqueeze_dim::<3>(0)
+            .expand([batch as i64, -1, -1]);
+
+        // Map the batch to the inner latent space.
+        // shape: [batch, token, inner_dim]
+        let queries = input.clone().matmul(w_query);
+        let keys = input.clone().matmul(w_key);
+        let values = input.clone().matmul(w_value);
+
+        // Split head dimensions
+        let keys = keys.reshape([batch, tokens, self.n_heads, self.head_dim]);
+        let values = values.reshape([batch, tokens, self.n_heads, self.head_dim]);
+        let queries = queries.reshape([batch, tokens, self.n_heads, self.head_dim]);
+
+        // from: [batch, tok, head, head_dim]
+        // to:   [batch, head, tok, head_dim]
+        let keys = keys.swap_dims(1, 2);
+        let values = values.swap_dims(1, 2);
+        let queries = queries.swap_dims(1, 2);
+
+        // Compute attention scores for each head
+        // (dot-product similarity of each query token to each context token, per head)
+        //
+        // lhs shape:    [batch, head, tok, head_dim]
+        // rhs shape:    [batch, head, head_dim, tok]
+        // output shape: [batch, head, query_token, context_token]
+        let attn_scores = queries.matmul(keys.clone().swap_dims(2, 3));
+
+        let mask = self
+            .utri_mask
+            .clone()
+            .slice([0..tokens, 0..tokens])
+            .unsqueeze_dim::<3>(0)
+            .unsqueeze_dim::<4>(0)
+            .expand(attn_scores.shape());
+
+        // Mask out future tokens by filling the upper
+        // triangle with -inf, which becomes 0.0 after softmax.
+        let attn_scores = attn_scores.mask_fill(mask, f32::NEG_INFINITY);
+
+        // Normalize attn weights.
+        //
+        // Divide by sqrt(head_dim) because...
+        // - dot products get larger with larger dimensions
+        // - this causes softmax to "saturate", making all other values very small
+        // - which makes gradients vanish during training
+        let attn_weights = softmax(attn_scores / (keys.shape()[3] as f32).sqrt(), 3);
+        let attn_weights = self.dropout.forward(attn_weights);
+
+        // lhs shape:    [batch, head, query_token, context_token]
+        // rhs shape:    [batch, head, tok, head_dim]
+        // matmul shape: [batch, head, tok, head_dim]
+        // out shape:    [batch, tok, head, head_dim]
+        let context_vec = attn_weights.matmul(values).swap_dims(1, 2);
+
+        // shape: [batch, tok, stacked_dim]
+        let context_vec = context_vec.reshape([batch, tokens, self.n_heads * self.head_dim]);
+
+        // Apply final projection (optional)
+        let context_vec = context_vec.matmul(w_output);
+
+        return context_vec;
+    }
+}
diff --git a/crates/llmfs/src/parts/mod.rs b/crates/llmfs/src/parts/mod.rs
new file mode 100644
index 0000000..3cdca3e
--- /dev/null
+++ b/crates/llmfs/src/parts/mod.rs
@@ -0,0 +1,5 @@
+mod attention;
+pub use attention::*;
+
+mod model;
+pub use model::*;
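
In equation form, each head in `forward` above computes masked, scaled dot-product attention, where the mask `M` sends every position after the query to negative infinity:

```latex
\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{Q_i K_i^{\top}}{\sqrt{d_{\mathrm{head}}}} + M\right) V_i,
\qquad
M_{jk} = \begin{cases} 0 & k \le j \\ -\infty & k > j \end{cases}
```

With the configuration `train_model.rs` passes in (`embed_dim = 768`, 12 heads of dimension 64, 256-token context), each `Q_i`, `K_i`, `V_i` is `[256, 64]` per batch element, each per-head score matrix is `[256, 256]`, and the re-concatenated heads are `[256, 768]` before the `w_output` projection.
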
diff --git a/crates/llmfs/src/parts/model.rs b/crates/llmfs/src/parts/model.rs
new file mode 100644
index 0000000..aacccf9
--- /dev/null
+++ b/crates/llmfs/src/parts/model.rs
@@ -0,0 +1,194 @@
+use burn::{
+    Tensor,
+    config::Config,
+    module::{Module, Param, ParamId},
+    nn::{
+        Dropout, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
+        loss::CrossEntropyLossConfig,
+        transformer::{PositionWiseFeedForward, PositionWiseFeedForwardConfig},
+    },
+    prelude::Backend,
+    tensor::{Distribution, Int},
+};
+use burn_train::ClassificationOutput;
+
+use crate::parts::{MultiheadAttention, MultiheadAttentionConfig};
+
+#[derive(Debug, Config)]
+pub struct GptModelConfig {
+    /// Number of tokens in the vocabulary
+    pub vocab_size: u32,
+
+    /// Maximum number of input tokens with positional embeddings
+    pub context_size: usize,
+
+    /// Dimension of each token's embedding
+    pub embed_dim: usize,
+
+    /// Number of attention heads
+    pub n_heads: usize,
+
+    /// Dimension of each attn head
+    pub head_dim: usize,
+
+    /// Number of transformer blocks
+    #[config(default = 12)]
+    pub n_layers: usize,
+
+    #[config(default = 0.1)]
+    pub embed_drop_rate: f64,
+
+    #[config(default = 0.1)]
+    pub attention_drop_rate: f64,
+
+    #[config(default = 0.1)]
+    pub shortcut_drop_rate: f64,
+}
+
+impl GptModelConfig {
+    pub fn init<B: Backend>(&self, device: &B::Device) -> GptModel<B> {
+        let out_head_shape = [self.embed_dim, self.vocab_size as usize];
+
+        GptModel {
+            embedder_tok: EmbeddingConfig::new(self.vocab_size as usize, self.embed_dim)
+                .init(device),
+
+            embedder_pos: EmbeddingConfig::new(self.context_size, self.embed_dim).init(device),
+
+            embedder_drop: Dropout {
+                prob: self.embed_drop_rate,
+            },
+
+            trf_blocks: (0..self.n_layers)
+                .map(|_| TransformerBlock::new(self, device))
+                .collect(),
+
+            final_norm: LayerNormConfig::new(self.embed_dim).init(device),
+
+            out_head: Param::uninitialized(
+                ParamId::new(),
+                move |device, is_require_grad| {
+                    Tensor::random(out_head_shape, Distribution::Default, device)
+                        .set_require_grad(is_require_grad)
+                },
+                device.clone(),
+                true,
+                out_head_shape.into(),
+            ),
+        }
+    }
+}
+
+#[derive(Module, Debug)]
+pub struct GptModel<B: Backend> {
+    embedder_tok: Embedding<B>,
+    embedder_pos: Embedding<B>,
+    embedder_drop: Dropout,
+
+    trf_blocks: Vec<TransformerBlock<B>>,
+    final_norm: LayerNorm<B>,
+    out_head: Param<Tensor<B, 2>>,
+}
+
+impl<B: Backend> GptModel<B> {
+    pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
+        let n_tokens = input.shape()[1];
+
+        let embed_tok = self.embedder_tok.forward(input.clone());
+        let embed_pos = self
+            .embedder_pos
+            .forward(Tensor::arange(0..n_tokens as i64, &input.device()).unsqueeze_dim(0));
+
+        let x = embed_tok + embed_pos;
+        let x = self.embedder_drop.forward(x);
+        let x = self.trf_blocks.iter().fold(x, |x, l| l.forward(x));
+        let x = self.final_norm.forward(x);
+        let logits = x.matmul(self.out_head.val().unsqueeze_dim(0));
+
+        return logits;
+    }
+
+    pub fn forward_train(
+        &self,
+        inputs: Tensor<B, 2, Int>,
+        targets: Tensor<B, 1, Int>,
+    ) -> ClassificationOutput<B> {
+        // shape: [batch, n_tokens, n_vocabulary]
+        let output = self.forward(inputs);
+
+        // Keep only the logits of the last token
+        // shape: [batch, n_vocabulary]
+        let output = output.slice_dim(1, -1).squeeze_dim::<2>(1);
+
+        let loss = CrossEntropyLossConfig::new()
+            .init(&targets.device())
+            .forward(output.clone(), targets.clone());
+
+        ClassificationOutput {
+            loss,
+            output,
+            targets,
+        }
+    }
+}
+
+#[derive(Module, Debug)]
+pub struct TransformerBlock<B: Backend> {
+    attention: MultiheadAttention<B>,
+
+    /// Position-wise feed-forward network: expands each token's
+    /// embedding (here, to 4 * embed_dim), applies a nonlinearity,
+    /// and projects back down.
+    ff: PositionWiseFeedForward<B>,
+
+    /// Pre-norm layer norms: one before the attention sublayer,
+    /// one before the feed-forward sublayer.
+    norm_a: LayerNorm<B>,
+    norm_b: LayerNorm<B>,
+
+    drop_shortcut: Dropout,
+}
+
+impl<B: Backend> TransformerBlock<B> {
+    pub fn new(cfg: &GptModelConfig, device: &B::Device) -> Self {
+        Self {
+            attention: MultiheadAttentionConfig {
+                embed_dim: cfg.embed_dim,
+                head_dim: cfg.head_dim,
+                n_heads: cfg.n_heads,
+                context_size: cfg.context_size,
+                drop_rate: cfg.attention_drop_rate,
+            }
+            .init(device),
+
+            ff: PositionWiseFeedForwardConfig::new(cfg.embed_dim, 4 * cfg.embed_dim)
+                .with_dropout(0.0)
+                .init(device),
+
+            norm_a: LayerNormConfig::new(cfg.embed_dim).init(device),
+            norm_b: LayerNormConfig::new(cfg.embed_dim).init(device),
+
+            drop_shortcut: Dropout {
+                prob: cfg.shortcut_drop_rate,
+            },
+        }
+    }
+
+    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
+        // Attention sublayer with a residual shortcut.
+        let input = {
+            let shortcut = input.clone();
+            let x = self.norm_a.forward(input);
+            let x = self.attention.forward(x);
+            let x = self.drop_shortcut.forward(x);
+            x + shortcut
+        };
+
+        // Feed-forward sublayer, same shortcut pattern.
+        let input = {
+            let shortcut = input.clone();
+            let x = self.norm_b.forward(input);
+            let x = self.ff.forward(x);
+            let x = self.drop_shortcut.forward(x);
+            x + shortcut
+        };
+
+        return input;
+    }
+}
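
Note that `forward_train` above supervises only the final position of each window against a single target token. GPT-style training usually scores every position against a target sequence shifted left by one. A sketch of that variant, written as an extra method for the same `impl` block and reusing its imports; the `[batch, n_tokens]` target shape is my assumption, not something this patch provides:

```rust
// Per-position supervision (sketch, not part of this patch): flatten
// [batch, n_tokens] into batch * n_tokens classification examples.
pub fn forward_train_all(
    &self,
    inputs: Tensor<B, 2, Int>,
    targets: Tensor<B, 2, Int>,
) -> Tensor<B, 1> {
    // shape: [batch, n_tokens, n_vocabulary]
    let logits = self.forward(inputs);
    let [b, t, v] = logits.dims();

    // Each position becomes one cross-entropy example.
    let logits = logits.reshape([b * t, v]);
    let targets = targets.reshape([b * t]);

    CrossEntropyLossConfig::new()
        .init(&targets.device())
        .forward(logits, targets)
}
```
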
diff --git a/crates/llmfs/src/train_test_iterator.rs b/crates/llmfs/src/train_test_iterator.rs
new file mode 100644
index 0000000..65a06c5
--- /dev/null
+++ b/crates/llmfs/src/train_test_iterator.rs
@@ -0,0 +1,164 @@
+use ahash::AHasher;
+use anyhow::Result;
+use burn::{
+    Tensor,
+    prelude::{Backend, ToElement},
+    tensor::Int,
+};
+use ndarray::{Array1, Array2};
+use std::{collections::VecDeque, hash::Hasher, path::Path};
+use tokenizer::Tokenizer;
+
+use crate::data_reader::{DataReader, DataReaderError};
+
+#[derive(Debug, Clone)]
+pub struct TrainBatch<B: Backend> {
+    /// Input texts.
+    /// shape: [batch, context_size]
+    pub inputs: Tensor<B, 2, Int>,
+
+    /// Correct next token for each input.
+    /// shape: [batch]
+    pub targets: Tensor<B, 1, Int>,
+}
+
+/// Read texts from a [DataReader], then
+/// - extract context windows
+/// - deterministically classify these as "test" or "train"
+/// - batch output into tensors of token ids
+pub struct TrainTestIterator<'a, B: Backend> {
+    reader: DataReader,
+
+    tokenizer: &'a Tokenizer,
+    eval: bool,
+    device: &'a B::Device,
+
+    batch_size: usize,
+    context_size: usize,
+    eval_frac: f64,
+    eval_salt: String,
+
+    // Tokenized input/output pairs
+    pairs: VecDeque<(Vec<u32>, u32)>,
+    error: bool,
+}
+
+impl<'a, B: Backend> TrainTestIterator<'a, B> {
+    pub fn new(
+        data_dir: impl AsRef<Path>,
+        tokenizer: &'a Tokenizer,
+        eval: bool,
+        batch_size: usize,
+        context_size: usize,
+        eval_frac: f64,
+        eval_salt: impl Into<String>,
+        readers: usize,
+        device: &'a B::Device,
+    ) -> Result<Self> {
+        let reader = DataReader::new(readers.max(1), data_dir)?;
+
+        Ok(Self {
+            reader,
+            tokenizer,
+            eval,
+            device,
+
+            batch_size,
+            context_size,
+            eval_frac,
+            eval_salt: eval_salt.into(),
+
+            pairs: VecDeque::new(),
+            error: false,
+        })
+    }
+}
+
+impl<B: Backend> Iterator for TrainTestIterator<'_, B> {
+    type Item = Result<TrainBatch<B>, DataReaderError>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.error {
+            return None;
+        }
+
+        let mut inputs = Vec::with_capacity(self.batch_size);
+        let mut targets = Vec::with_capacity(self.batch_size);
+        let stride = self.context_size;
+
+        while inputs.len() < self.batch_size {
+            match self.pairs.pop_front() {
+                Some((i, t)) => {
+                    // Train/test split: hash the salted window and its
+                    // target, then compare against the eval fraction.
+                    {
+                        let mut hasher = AHasher::default();
+                        hasher.write(self.eval_salt.as_bytes());
+
+                        // Don't care about endianness, ahash output is unstable anyway
+                        for tok in &i {
+                            hasher.write_u32(*tok);
+                        }
+                        hasher.write_u32(t);
+
+                        // Is this window in the training set?
+                        let train =
+                            hasher.finish() > (u64::MAX as f64 * self.eval_frac).to_u64();
+
+                        if train && self.eval {
+                            continue;
+                        }
+                    }
+
+                    inputs.push(i);
+                    targets.push(t);
+                }
+
+                None => {
+                    let text = match self.reader.next() {
+                        None => break,
+                        Some(Ok(x)) => x,
+                        Some(Err(x)) => {
+                            self.error = true;
+                            return Some(Err(x));
+                        }
+                    };
+
+                    let emb = self.tokenizer.encode(&text);
+
+                    // Skip small texts
+                    //
+                    // TODO: do this better
+                    // TODO: maybe using <|bos|>?
+                    // TODO: non-uniform batches?
+                    if emb.len() < self.context_size {
+                        continue;
+                    }
+
+                    let pairs = emb
+                        .windows(self.context_size + 1)
+                        .step_by(stride)
+                        .map(|x| (x[..self.context_size].to_vec(), x[self.context_size]));
+
+                    self.pairs.extend(pairs);
+                }
+            }
+        }
+
+        if inputs.is_empty() {
+            return None;
+        }
+
+        let shape = [inputs.len(), self.context_size];
+
+        // Arrange data in memory
+        let inputs: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| inputs[a][b]);
+        let targets: Array1<u32> = Array1::from_vec(targets);
+
+        // Create tensors on the compute device
+        #[expect(clippy::unwrap_used)]
+        let inputs = Tensor::<B, 1, Int>::from_ints(inputs.as_slice().unwrap(), self.device)
+            .reshape(shape);
+
+        #[expect(clippy::unwrap_used)]
+        let targets = Tensor::<B, 1, Int>::from_ints(targets.as_slice().unwrap(), self.device);
+
+        return Some(Ok(TrainBatch { inputs, targets }));
+    }
+}
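
The `windows` / `step_by` pairing above carves each tokenized text into non-overlapping context windows, each paired with the token that immediately follows it. A standalone illustration with `context_size = 3` (so `stride = 3`):

```rust
fn main() {
    let emb: Vec<u32> = vec![10, 11, 12, 13, 14, 15, 16, 17];
    let context_size = 3;

    let pairs: Vec<(Vec<u32>, u32)> = emb
        .windows(context_size + 1)
        .step_by(context_size)
        .map(|x| (x[..context_size].to_vec(), x[context_size]))
        .collect();

    // Windows of 4 tokens, starting every 3 tokens:
    // ([10, 11, 12], 13) and ([13, 14, 15], 16).
    // The tail (17) is dropped because no full window remains.
    assert_eq!(pairs, vec![(vec![10, 11, 12], 13), (vec![13, 14, 15], 16)]);
}
```
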
diff --git a/crates/tokenizer/src/tokenizer.rs b/crates/tokenizer/src/tokenizer.rs
index 2c28621..11fb5ea 100644
--- a/crates/tokenizer/src/tokenizer.rs
+++ b/crates/tokenizer/src/tokenizer.rs
@@ -19,8 +19,7 @@ use tracing::{debug, info};
 
 use crate::{progress_big, split::regex_segment};
 
-// TODO:
-// - maybe don't use regex
+// Perf: segmentation currently uses a regex; consider something faster.
 #[derive(Debug, Clone, thiserror::Error)]
 pub enum TokenizerTrainError {