diff --git a/Cargo.lock b/Cargo.lock index 0767f61..c6bf878 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "addr2line" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b5d307320b3181d6d7954e663bd7c774a838b8220fe0593c86d9fb09f498b4b" +dependencies = [ + "gimli", +] + [[package]] name = "adler2" version = "2.0.1" @@ -273,6 +282,21 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "backtrace" +version = "0.3.76" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb531853791a215d7c62a30daf0dde835f381ab5de4589cfe7c649d2cbe92bd6" +dependencies = [ + "addr2line", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", + "windows-link 0.2.1", +] + [[package]] name = "base64" version = "0.22.1" @@ -383,9 +407,8 @@ checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" [[package]] name = "burn" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0291ea5c68786545e239a02f63331cfe39da7485164ae05197d5be6f148d0557" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ "burn-autodiff", "burn-candle", @@ -404,60 +427,64 @@ dependencies = [ [[package]] name = "burn-autodiff" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "917423a74bf4d39f17a6799089869648e3d2b6ac89d93901aab4aeb9a7f82138" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "burn-common", - "burn-tensor", + "burn-backend", + "burn-std", "derive-new", - "hashbrown 0.15.5", + "hashbrown 0.16.1", "log", "num-traits", + "parking_lot", "portable-atomic", "spin", ] [[package]] -name = "burn-candle" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2891811d41ae30b5f1f660e7615b757b2cb4128af5e311b213656de3875e4acb" +name = "burn-backend" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "burn-common", - "burn-tensor", - "candle-core", + "burn-std", + "bytemuck", + "cubecl", "derive-new", - "half", + "hashbrown 0.16.1", + "num-traits", + "rand", + "rand_distr", + "serde", + "thiserror 2.0.17", ] [[package]] -name = "burn-common" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5eb445304e4f91f8633d23c9a5258cd93639d13ce2ee47d4821fd519b683bf02" +name = "burn-candle" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "cubecl-common", - "rayon", - "serde", + "burn-backend", + "burn-std", + "candle-core", + "derive-new", ] [[package]] name = "burn-core" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20c93e754864080a8c27b9a47e3b6f7d79013cf82c9ce00ed57c9ba51a3e34c5" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ "ahash", "bincode", - "burn-common", + "burn-dataset", "burn-derive", + "burn-std", "burn-tensor", "data-encoding", "derive-new", "flate2", 
"half", - "hashbrown 0.15.5", + "hashbrown 0.16.1", "log", "num-traits", "portable-atomic", @@ -473,84 +500,82 @@ dependencies = [ [[package]] name = "burn-cpu" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4807930d243f1aa9dde99db372af56ac532cc6635fd3187156aee375fbadc07" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ + "burn-backend", "burn-cubecl", "burn-fusion", - "burn-tensor", - "bytemuck", "cubecl", - "derive-new", - "half", - "log", ] [[package]] name = "burn-cubecl" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dd16308b7b0291c77f2d7acf428bc8254ec3db88a430a26cf3d3b0b63ae2d46" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "burn-common", + "burn-backend", "burn-cubecl-fusion", "burn-fusion", "burn-ir", - "burn-tensor", - "bytemuck", + "burn-std", "cubecl", - "cubecl-quant", + "cubek", "derive-new", "futures-lite", - "half", - "hashbrown 0.15.5", "log", - "num-traits", - "rand", "serde", - "spin", "text_placeholder", ] [[package]] name = "burn-cubecl-fusion" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc21cf88201dfbf242cadb638a0cc924010727fc37d6a719f7e10548b339c63a" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "burn-common", + "burn-backend", "burn-fusion", "burn-ir", - "burn-tensor", + "burn-std", "cubecl", - "cubecl-quant", + "cubek", "derive-new", - "half", "serde", ] [[package]] name = "burn-cuda" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e104dcf07eac70c7b5864b51d792df3360b11b00febb60543b4283bb414bb61" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ + "burn-backend", "burn-cubecl", "burn-fusion", - "burn-tensor", - "bytemuck", "cubecl", +] + +[[package]] +name = "burn-dataset" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" +dependencies = [ + "csv", "derive-new", - "half", - "log", + "dirs", + "rand", + "rmp-serde", + "sanitize-filename", + "serde", + "serde_json", + "strum", + "tempfile", + "thiserror 2.0.17", ] [[package]] name = "burn-derive" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bcf49261de086b8206de6c8962d2adf23feb476119a18e384f5b2c9af07c0cf" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ "derive-new", "proc-macro2", @@ -560,16 +585,13 @@ dependencies = [ [[package]] name = "burn-fusion" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "662bf2679c04be34a0c3f1b11f77f6ff49456af1620d1eca311bc2562bbb56c9" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "burn-common", + "burn-backend", "burn-ir", - "burn-tensor", "derive-new", - "half", - "hashbrown 0.15.5", + "hashbrown 0.16.1", "log", "serde", "spin", @@ -577,45 +599,44 @@ dependencies = [ [[package]] name = "burn-ir" -version = "0.19.1" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "9161239d5691c4ab6f470f2c65aaec5c0a7c1f0b0da390700bcd59f5a77d1d7b" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "burn-tensor", - "hashbrown 0.15.5", + "burn-backend", + "hashbrown 0.16.1", "portable-atomic-util", "serde", ] [[package]] name = "burn-ndarray" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b78bcf4a3508043342f918e796dc79108b5f3252398403eb73952847e7683374" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ "atomic_float", "burn-autodiff", - "burn-common", + "burn-backend", "burn-ir", - "burn-tensor", + "burn-std", + "bytemuck", "const-random", - "derive-new", + "itertools 0.14.0", "libm", "macerator", "matrixmultiply", - "ndarray", + "ndarray 0.17.1", "num-traits", "paste", "portable-atomic-util", "rand", - "spin", + "rayon", + "seq-macro", ] [[package]] name = "burn-nn" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc7829c87c4dd6c7929b50fd981e7e8d1b77414323da30ce2067a3e8b7ea422b" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ "burn-core", "num-traits", @@ -623,13 +644,12 @@ dependencies = [ [[package]] name = "burn-optim" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31758c02e50247f12457fca1905ed8684ac1b1c5292e10cbbfffb9fa0048d4bd" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ "burn-core", "derive-new", - "hashbrown 0.15.5", + "hashbrown 0.16.1", "log", "num-traits", "serde", @@ -637,81 +657,99 @@ dependencies = [ [[package]] name = "burn-rocm" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e1ceb87b6e7349b42d7995477c9a69d0e6c458c64eafa10af3b8b9070f260aa" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ + "burn-backend", "burn-cubecl", "burn-fusion", - "burn-tensor", - "bytemuck", "cubecl", - "derive-new", - "half", - "log", ] [[package]] name = "burn-router" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45f40403c500b5df380bee47aa0f23032350bdfde5402812d6fcec4d6ff6fbad" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "burn-common", + "burn-backend", "burn-ir", - "burn-tensor", - "hashbrown 0.15.5", + "burn-std", + "hashbrown 0.16.1", "log", "spin", ] +[[package]] +name = "burn-std" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" +dependencies = [ + "bytemuck", + "bytes", + "cubecl", + "cubecl-common", + "half", + "num-traits", + "serde", +] + [[package]] name = "burn-store" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a2a163486242fcb0c6e2cb89c5a803ab8588673652bb46ecd7af6378d06152f" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ "burn-core", "burn-nn", "burn-tensor", "byteorder", + "bytes", "half", - 
"hashbrown 0.15.5", + "hashbrown 0.16.1", "memmap2", "regex", - "safetensors 0.6.2", + "safetensors 0.7.0", + "textdistance", ] [[package]] name = "burn-tensor" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df8861f7c21d3b07a2b19d028f6eb8903990949708b2ec825559b5200786877c" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ - "burn-common", - "bytemuck", + "burn-backend", + "burn-std", "colored", - "cubecl", - "cubecl-quant", "derive-new", - "half", - "hashbrown 0.15.5", "num-traits", - "rand", - "rand_distr", "serde", - "serde_bytes", +] + +[[package]] +name = "burn-train" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" +dependencies = [ + "async-channel", + "burn-core", + "burn-ndarray", + "burn-optim", + "derive-new", + "log", + "rstest", + "serde", + "tracing-appender", + "tracing-core", + "tracing-subscriber", ] [[package]] name = "burn-wgpu" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c17aeaa2eadaa4831a64672b99f62ffcdf4874fe4757080633d8a6c4452e2b38" +version = "0.20.0-pre.5" +source = "git+https://github.com/tracel-ai/burn.git#fef06c92053624de6ba7e4eccabf5199b33c77ca" dependencies = [ + "burn-backend", "burn-cubecl", "burn-fusion", - "burn-tensor", "cubecl", ] @@ -746,6 +784,9 @@ name = "bytes" version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" +dependencies = [ + "portable-atomic", +] [[package]] name = "candle-core" @@ -1057,6 +1098,15 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -1099,19 +1149,36 @@ dependencies = [ ] [[package]] -name = "cubecl" -version = "0.8.1" +name = "csv" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8b7c74ecaca9356c9ae79d0ebf1db04f02bd98be09eea61f51d73373dffe758" +checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938" dependencies = [ - "cubecl-convolution", + "csv-core", + "itoa", + "ryu", + "serde_core", +] + +[[package]] +name = "csv-core" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782" +dependencies = [ + "memchr", +] + +[[package]] +name = "cubecl" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" +dependencies = [ + "cfg_aliases", "cubecl-core", "cubecl-cpu", "cubecl-cuda", "cubecl-hip", - "cubecl-matmul", - "cubecl-random", - "cubecl-reduce", "cubecl-runtime", "cubecl-std", "cubecl-wgpu", @@ -1120,11 +1187,12 @@ dependencies = [ [[package]] name = "cubecl-common" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4556981155bffc057a8effcd4549b52b51df3e9edec43af6ccae2dd03fc8fbff" +version = "0.9.0-pre.5" +source = 
"git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ + "backtrace", "bytemuck", + "bytes", "cfg-if", "cfg_aliases", "derive-new", @@ -1152,30 +1220,10 @@ dependencies = [ "web-time", ] -[[package]] -name = "cubecl-convolution" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27c624ec400b7203673bf2db86d7ff30d1384839d497d2dd029c19b1b7371e0d" -dependencies = [ - "bytemuck", - "cubecl-common", - "cubecl-core", - "cubecl-matmul", - "cubecl-random", - "cubecl-reduce", - "cubecl-runtime", - "cubecl-std", - "half", - "pretty_assertions", - "serde", -] - [[package]] name = "cubecl-core" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ffc10af538ee74535cda260e581f5a177c243803dd30b698934a515f0114b55" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "bitflags 2.10.0", "bytemuck", @@ -1198,9 +1246,8 @@ dependencies = [ [[package]] name = "cubecl-cpp" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d630e4d10cdd3af268ac753914ca79b48f01d1e36c5b5039970a817acc925fea" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "bytemuck", "cubecl-common", @@ -1215,17 +1262,13 @@ dependencies = [ [[package]] name = "cubecl-cpu" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac1693555277d74152afb61a23e30d1f17d72cebd317a648faf50a8e69380f08" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "bytemuck", "cubecl-common", - "cubecl-convolution", "cubecl-core", - "cubecl-matmul", "cubecl-opt", - "cubecl-reduce", "cubecl-runtime", "cubecl-std", "derive-new", @@ -1239,9 +1282,8 @@ dependencies = [ [[package]] name = "cubecl-cuda" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67215fcd552a9e8bc68494a71cf2979f2e2bbcbda60f0695f56f86705b89ed5f" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "bytemuck", "cubecl-common", @@ -1257,16 +1299,14 @@ dependencies = [ [[package]] name = "cubecl-hip" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5e2e6a257f702fb2eb6f24e640e228a94695e4a4c73a4c549578cbb02ad4ec5" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "bytemuck", "cubecl-common", "cubecl-core", "cubecl-cpp", "cubecl-hip-sys", - "cubecl-quant", "cubecl-runtime", "derive-new", "half", @@ -1287,14 +1327,14 @@ dependencies = [ [[package]] name = "cubecl-ir" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf5d3aa7857e6aee1622aef128d6ad8d9289ed57362b4e65d10cc182aafc585f" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ 
"cubecl-common", "cubecl-macros-internal", "derive-new", "derive_more", + "enumset", "float-ord", "fnv", "half", @@ -1307,9 +1347,8 @@ dependencies = [ [[package]] name = "cubecl-macros" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5200fb619be424749901e3c6e8e66ae71146c8f83636a74f171bd980cba379d7" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "cubecl-common", "darling 0.21.3", @@ -1323,9 +1362,8 @@ dependencies = [ [[package]] name = "cubecl-macros-internal" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a1b673f303396fba18df83368aa4eced474584f1bca34852dccc42bd4ff050c" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "darling 0.21.3", "proc-macro2", @@ -1333,29 +1371,10 @@ dependencies = [ "syn", ] -[[package]] -name = "cubecl-matmul" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1cf0a00609a249d5357c27cafea477f35218579db2ab00582d8d5800be4a5a3" -dependencies = [ - "bytemuck", - "cubecl-common", - "cubecl-core", - "cubecl-random", - "cubecl-reduce", - "cubecl-runtime", - "cubecl-std", - "half", - "pretty_assertions", - "serde", -] - [[package]] name = "cubecl-opt" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "870ca4b52f9eebd358c9b360b89cdc9f82bde05682db63f0e90c666b3c85a04d" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "cubecl-common", "cubecl-core", @@ -1369,57 +1388,10 @@ dependencies = [ "type-map", ] -[[package]] -name = "cubecl-quant" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9be3e1202c219078d85dbad7f30d1195fe4f9d42cbfad2c94ab0ea1a6d9f01f6" -dependencies = [ - "cubecl-common", - "cubecl-core", - "cubecl-runtime", - "cubecl-std", - "half", - "serde", -] - -[[package]] -name = "cubecl-random" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a293a05caa68663675823bab66205bca094a21a2c0f6686ad9f20b392516179" -dependencies = [ - "cubecl-common", - "cubecl-core", - "cubecl-runtime", - "cubecl-std", - "half", - "num-traits", - "rand", - "serde", -] - -[[package]] -name = "cubecl-reduce" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53306ace81f6262f7ae794370f47e6b5019842b27e8800240e5b039386b3ac3a" -dependencies = [ - "cubecl-core", - "cubecl-runtime", - "cubecl-std", - "half", - "num-traits", - "pretty_assertions", - "rand", - "serde", -] - [[package]] name = "cubecl-runtime" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91b823bb5899a6fa8809bf7aa36f93f72ced6de58ab9d6edea2c730b235eeda3" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "async-channel", "bytemuck", @@ -1430,7 +1402,7 @@ dependencies = [ "derive-new", "dirs", "enumset", - "foldhash", + "foldhash 0.1.5", "hashbrown 0.15.5", "log", "md5", @@ -1445,15 +1417,15 @@ dependencies = [ [[package]] name = 
"cubecl-std" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24536998f9fff84f9a1dd2a90f981d5aa4d15eb35cddec5021c4fcf977d2e75e" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "cubecl-common", "cubecl-core", "cubecl-runtime", - "foldhash", + "foldhash 0.1.5", "half", + "num-traits", "paste", "serde", "spin", @@ -1462,9 +1434,8 @@ dependencies = [ [[package]] name = "cubecl-wgpu" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d59a7d737259a784247595e2f0cc5a97d3e50f45cdaefbd4cc7d7fd2126f7a58" +version = "0.9.0-pre.5" +source = "git+https://github.com/tracel-ai/cubecl?rev=afa4a91a876e18c54153a8094fca469d7ba0817a#afa4a91a876e18c54153a8094fca469d7ba0817a" dependencies = [ "async-channel", "bytemuck", @@ -1482,11 +1453,103 @@ dependencies = [ "wgpu", ] +[[package]] +name = "cubek" +version = "0.0.1" +source = "git+https://github.com/tracel-ai/cubek?rev=4442ccf777030c3ae0dcb3ebadcc85066d96703f#4442ccf777030c3ae0dcb3ebadcc85066d96703f" +dependencies = [ + "cubecl", + "cubek-attention", + "cubek-convolution", + "cubek-matmul", + "cubek-quant", + "cubek-random", + "cubek-reduce", +] + +[[package]] +name = "cubek-attention" +version = "0.0.1" +source = "git+https://github.com/tracel-ai/cubek?rev=4442ccf777030c3ae0dcb3ebadcc85066d96703f#4442ccf777030c3ae0dcb3ebadcc85066d96703f" +dependencies = [ + "bytemuck", + "cubecl", + "cubecl-common", + "cubek-matmul", + "cubek-random", + "half", + "serde", +] + +[[package]] +name = "cubek-convolution" +version = "0.0.1" +source = "git+https://github.com/tracel-ai/cubek?rev=4442ccf777030c3ae0dcb3ebadcc85066d96703f#4442ccf777030c3ae0dcb3ebadcc85066d96703f" +dependencies = [ + "bytemuck", + "cubecl", + "cubecl-common", + "cubek-matmul", + "derive-new", + "half", + "serde", +] + +[[package]] +name = "cubek-matmul" +version = "0.0.1" +source = "git+https://github.com/tracel-ai/cubek?rev=4442ccf777030c3ae0dcb3ebadcc85066d96703f#4442ccf777030c3ae0dcb3ebadcc85066d96703f" +dependencies = [ + "bytemuck", + "cubecl", + "cubecl-common", + "cubek-random", + "cubek-reduce", + "half", + "serde", +] + +[[package]] +name = "cubek-quant" +version = "0.0.1" +source = "git+https://github.com/tracel-ai/cubek?rev=4442ccf777030c3ae0dcb3ebadcc85066d96703f#4442ccf777030c3ae0dcb3ebadcc85066d96703f" +dependencies = [ + "cubecl", + "cubecl-common", + "half", + "serde", +] + +[[package]] +name = "cubek-random" +version = "0.0.1" +source = "git+https://github.com/tracel-ai/cubek?rev=4442ccf777030c3ae0dcb3ebadcc85066d96703f#4442ccf777030c3ae0dcb3ebadcc85066d96703f" +dependencies = [ + "cubecl", + "cubecl-common", + "half", + "num-traits", + "rand", + "serde", +] + +[[package]] +name = "cubek-reduce" +version = "0.0.1" +source = "git+https://github.com/tracel-ai/cubek?rev=4442ccf777030c3ae0dcb3ebadcc85066d96703f#4442ccf777030c3ae0dcb3ebadcc85066d96703f" +dependencies = [ + "cubecl", + "half", + "num-traits", + "serde", + "thiserror 2.0.17", +] + [[package]] name = "cudarc" -version = "0.17.8" +version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf99ab37ee7072d64d906aa2dada9a3422f1d975cdf8c8055a573bc84897ed8" +checksum = "3aa12038120eb13347a6ae2ffab1d34efe78150125108627fd85044dd4d6ff1e" dependencies = [ "libloading", ] @@ -1573,6 +1636,15 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" +[[package]] +name = "deranged" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ececcb659e7ba858fb4f10388c250a7252eb0a27373f1a72b8748afdd248e587" +dependencies = [ + "powerfmt", +] + [[package]] name = "derive-new" version = "0.7.0" @@ -1624,12 +1696,6 @@ version = "1.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "abd57806937c9cc163efc8ea3910e00a62e2aeb0b8119f1793a978088f8f6b04" -[[package]] -name = "diff" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" - [[package]] name = "digest" version = "0.10.7" @@ -1809,6 +1875,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25b07a8dfbbbfc0064c0a6bdf9edcf966de6b1c33ce344bdeca3b41615452634" dependencies = [ "enumset_derive", + "serde", ] [[package]] @@ -1955,6 +2022,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "foreign-types" version = "0.3.2" @@ -2064,6 +2137,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.31" @@ -2355,6 +2434,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "gimli" +version = "0.32.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e629b9b98ef3dd8afe6ca2bd0f89306cec16d43d907889945bc5d6687f2f13c7" + [[package]] name = "gl_generator" version = "0.14.0" @@ -2496,7 +2581,7 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ "allocator-api2", "equivalent", - "foldhash", + "foldhash 0.1.5", "serde", ] @@ -2505,6 +2590,13 @@ name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash 0.2.0", + "serde", + "serde_core", +] [[package]] name = "heck" @@ -3044,15 +3136,18 @@ checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" name = "llmfs" version = "0.0.1" dependencies = [ + "ahash", "anstyle", "anyhow", "burn", + "burn-train", "clap", "futures-util", "indicatif", - "ndarray", + "ndarray 0.16.1", "parking_lot", "parquet", + "rand", "rayon", "reqwest", "serde", @@ -3149,7 +3244,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" dependencies = [ "autocfg", + "num_cpus", + "once_cell", "rawpointer", + "thread-tree", ] [[package]] @@ -3302,6 +3400,22 @@ dependencies = [ "serde", ] +[[package]] +name = "ndarray" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0c7c9125e8f6f10c9da3aad044cc918cf8784fa34de857b1aa68038eb05a50a9" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", + "rayon", +] + [[package]] name = "ndk-sys" version = "0.6.0+11769913" @@ -3379,6 +3493,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-integer" version = "0.1.46" @@ -3480,6 +3600,15 @@ dependencies = [ "objc2-core-foundation", ] +[[package]] +name = "object" +version = "0.37.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff76201f031d8863c38aa7f905eca4f53abbfa15f609db4277d44cd8938f33fe" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -3689,6 +3818,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.21" @@ -3704,16 +3839,6 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8cf8e6a8aa66ce33f63993ffc4ea4271eb5b0530a9002db8455ea6050c77bfa" -[[package]] -name = "pretty_assertions" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d" -dependencies = [ - "diff", - "yansi", -] - [[package]] name = "prettyplease" version = "0.2.37" @@ -3994,6 +4119,12 @@ version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" +[[package]] +name = "relative-path" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" + [[package]] name = "renderdoc-sys" version = "1.1.0" @@ -4084,6 +4215,41 @@ dependencies = [ "serde", ] +[[package]] +name = "rstest" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5a3193c063baaa2a95a33f03035c8a72b83d97a54916055ba22d35ed3839d49" +dependencies = [ + "futures-timer", + "futures-util", + "rstest_macros", +] + +[[package]] +name = "rstest_macros" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c845311f0ff7951c5506121a9ad75aec44d083c31583b2ea5a30bcb0b0abba0" +dependencies = [ + "cfg-if", + "glob", + "proc-macro-crate", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn", + "unicode-ident", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" + [[package]] name = "rustc-hash" version = "1.1.0" @@ -4177,10 +4343,11 @@ dependencies = [ [[package]] name = "safetensors" -version = "0.6.2" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "172dd94c5a87b5c79f945c863da53b2ebc7ccef4eca24ac63cca66a41aab2178" +checksum = "675656c1eabb620b921efea4f9199f97fc86e36dd6ffd1fbbe48d0f59a4987f5" dependencies = [ + "hashbrown 0.16.1", "serde", "serde_json", ] @@ -4632,6 +4799,12 @@ 
dependencies = [ "serde_json", ] +[[package]] +name = "textdistance" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa672c55ab69f787dbc9126cc387dbe57fdd595f585e4524cf89018fa44ab819" + [[package]] name = "thiserror" version = "1.0.69" @@ -4672,6 +4845,15 @@ dependencies = [ "syn", ] +[[package]] +name = "thread-tree" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffbd370cb847953a25954d9f63e14824a36113f8c72eecf6eccef5dc4b45d630" +dependencies = [ + "crossbeam-channel", +] + [[package]] name = "thread_local" version = "1.1.9" @@ -4692,6 +4874,37 @@ dependencies = [ "ordered-float 2.10.1", ] +[[package]] +name = "time" +version = "0.3.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" + +[[package]] +name = "time-macros" +version = "0.2.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "tiny-keccak" version = "2.0.2" @@ -4989,6 +5202,18 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-appender" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "786d480bce6247ab75f005b14ae1624ad978d3029d9113f0a22fa1ac773faeaf" +dependencies = [ + "crossbeam-channel", + "thiserror 2.0.17", + "time", + "tracing-subscriber", +] + [[package]] name = "tracing-attributes" version = "0.1.31" @@ -5997,12 +6222,6 @@ version = "0.8.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3ae8337f8a065cfc972643663ea4279e04e7256de865aa66fe25cec5fb912d3f" -[[package]] -name = "yansi" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" - [[package]] name = "yoke" version = "0.7.5" diff --git a/Cargo.toml b/Cargo.toml index 82380cd..0c6d59a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,10 +75,12 @@ compact_str = "0.9.0" dary_heap = "0.3.8" fancy-regex = "0.16.2" indicatif = { version = "0.18.3", features = ["improved_unicode"] } +itertools = "0.14.0" futures-util = "0.3.31" ndarray = { version = "0.16.1", features = ["serde"] } parking_lot = "0.12.5" parquet = "56.2.0" +rand = "0.9.2" rayon = "1.11.0" reqwest = { version = "0.12.24", features = ["json", "stream"] } serde = "1.0.228" @@ -91,7 +93,11 @@ tracing-indicatif = "0.3.13" tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] } url = "2.5.7" + +burn-train = { git = "https://github.com/tracel-ai/burn.git", default-features = false } + [workspace.dependencies.burn] -version = "0.19.1" +#version = "0.19.1" +git = "https://github.com/tracel-ai/burn.git" default-features = false -features = ["std", "fusion", "ndarray", "webgpu", "cuda"] +features = ["std", "fusion", "ndarray", "webgpu", "cuda", "autodiff"] diff --git a/crates/llmfs/Cargo.toml b/crates/llmfs/Cargo.toml index 69bd819..850317d 100644 --- a/crates/llmfs/Cargo.toml +++ b/crates/llmfs/Cargo.toml @@ -10,15 +10,18 @@ 
workspace = true [dependencies] tokenizer = { workspace = true } +ahash = { workspace = true } anstyle = { workspace = true } anyhow = { workspace = true } burn = { workspace = true } +burn-train = { workspace = true } clap = { workspace = true } futures-util = { workspace = true } indicatif = { workspace = true } ndarray = { workspace = true } parking_lot = { workspace = true } parquet = { workspace = true } +rand = { workspace = true } rayon = { workspace = true } reqwest = { workspace = true } serde = { workspace = true } diff --git a/crates/llmfs/src/command/mod.rs b/crates/llmfs/src/command/mod.rs index 779132d..cfa67e1 100644 --- a/crates/llmfs/src/command/mod.rs +++ b/crates/llmfs/src/command/mod.rs @@ -15,7 +15,6 @@ pub enum SubCommand { #[command(flatten)] args: train_tokenizer::TrainTokenizerArgs, }, - /// Sample data SampleData { #[command(flatten)] diff --git a/crates/llmfs/src/command/sample_data.rs b/crates/llmfs/src/command/sample_data.rs index 01b7662..a91bb0f 100644 --- a/crates/llmfs/src/command/sample_data.rs +++ b/crates/llmfs/src/command/sample_data.rs @@ -1,22 +1,206 @@ +use ahash::AHasher; use anyhow::{Context, Result}; use burn::{ Tensor, - backend::{Cuda, cuda::CudaDevice}, - module::{Module, Param, ParamId}, + backend::{Autodiff, Cuda, cuda::CudaDevice}, + config::Config, + module::{AutodiffModule, Module, Param, ParamId}, nn::{ Dropout, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig, + loss::CrossEntropyLossConfig, transformer::{PositionWiseFeedForward, PositionWiseFeedForwardConfig}, }, - prelude::Backend, + optim::{AdamConfig, GradientsParams, Optimizer}, + prelude::{Backend, ToElement}, tensor::{Bool, Distribution, Int, activation::softmax}, }; +use burn_train::ClassificationOutput; use clap::Args; use indicatif::MultiProgress; -use ndarray::Array2; -use std::{f32, fs::File, path::PathBuf}; +use ndarray::{Array1, Array2}; +use std::{ + collections::VecDeque, + f32, + fs::File, + hash::Hasher, + path::{Path, PathBuf}, +}; use tokenizer::Tokenizer; +use tracing::{debug, info}; -use crate::data_reader::DataReader; +use crate::data_reader::{DataReader, DataReaderError}; + +// Text generation routine + +/* +{ + let init = "Initial context. 
This is "; + let tokens = tokenizer.encode(&init); + + let n_tokens = tokens.len(); + let input: Array1 = Array1::from_vec(tokens); + let mut input: Tensor = + Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device) + .reshape([n_tokens]); + + for _ in 0..100 { + let tokens: Vec = input.clone().to_data().convert::().into_vec().unwrap(); + println!("{:?}", tokens); + println!("{}", tokenizer.decode(&tokens)); + + // Crop idx to context size; + let batch = input + .clone() + .slice([0..config.context_size]) + .unsqueeze_dim(0); + + // shape: [tokens, vocab_size] + let logits = model.forward(batch).squeeze_dim::<2>(0); + + // shape: [vocab_size] + let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0); + + let probs = softmax(logits, 0); // shape: [n_tokens] + let id_next = probs.argmax(0); // shape: [1] + + input = Tensor::cat(vec![input.slice([1..]), id_next], 0); + } +} +*/ + +struct TrainTestIterator<'a, B: Backend> { + reader: DataReader, + + ccfg: &'a ComputeConfig, + mcfg: &'a GptModelConfig, + tokenizer: &'a Tokenizer, + eval: bool, + device: &'a B::Device, + + error: bool, + + // Tokenized input/output pairs + pairs: VecDeque<(Vec, u32)>, +} + +impl<'a, B: Backend> TrainTestIterator<'a, B> { + pub fn new( + data_dir: impl AsRef, + ccfg: &'a ComputeConfig, + mcfg: &'a GptModelConfig, + tokenizer: &'a Tokenizer, + eval: bool, + device: &'a B::Device, + ) -> Result { + let reader = DataReader::new(10, data_dir)?; + + Ok(Self { + reader, + ccfg, + mcfg, + tokenizer, + eval, + device, + + error: false, + pairs: VecDeque::new(), + }) + } +} + +impl Iterator for TrainTestIterator<'_, B> { + type Item = Result, DataReaderError>; + + fn next(&mut self) -> Option { + if self.error { + return None; + } + + let mut inputs = Vec::with_capacity(self.ccfg.batch_size); + let mut targets = Vec::with_capacity(self.ccfg.batch_size); + let stride = self.mcfg.context_size; + + while inputs.len() < self.ccfg.batch_size { + match self.pairs.pop_front() { + Some((i, t)) => { + // train/test split + { + let mut hasher = AHasher::default(); + hasher.write(self.ccfg.eval_salt.as_bytes()); + + // Don't care about endianness, ahash output is unstable anyway + hasher.write(unsafe { std::mem::transmute(&i[..]) }); + hasher.write_u32(t); + + let test = // is this point in the test set? + hasher.finish() > (u64::MAX as f64 * self.ccfg.eval_frac).to_u64(); + + if test ^ self.eval { + continue; + } + } + + inputs.push(i); + targets.push(t); + } + + None => { + let text = match self.reader.next() { + None => break, + Some(Ok(x)) => x, + Some(Err(x)) => { + self.error = true; + return Some(Err(x)); + } + }; + + let emb = self.tokenizer.encode(&text); + + // Skip small texts + // + // TODO: do this better + // TODO: maybe using <|bos|>? + // TODO: non-uniform batches? 
+ if emb.len() < self.mcfg.context_size { + continue; + } + + let pairs = emb + .windows(self.mcfg.context_size + 1) + .step_by(stride) + .map(|x| { + ( + x[..self.mcfg.context_size].to_vec(), + x[self.mcfg.context_size], + ) + }); + + self.pairs.extend(pairs); + } + } + } + + if inputs.is_empty() { + return None; + } + + let shape = [inputs.len(), self.mcfg.context_size]; + + // Arrange data in memory + let inputs: Array2 = Array2::from_shape_fn(shape, |(a, b)| inputs[a][b]); + let targets: Array1 = Array1::from_vec(targets); + + // Create tensors on gpu + #[expect(clippy::unwrap_used)] + let inputs = + Tensor::::from_ints(inputs.as_slice().unwrap(), self.device).reshape(shape); + + #[expect(clippy::unwrap_used)] + let targets = Tensor::::from_ints(targets.as_slice().unwrap(), self.device); + + return Some(Ok(TrainBatch { inputs, targets })); + } +} #[derive(Debug, Args, Clone)] @@ -32,190 +216,142 @@ pub struct SampleDataArgs { /// How many texts to return #[clap(long, short = 'n', default_value = "10")] n: usize, - - /// How many texts to skip - #[clap(long, short = 's', default_value = "0")] - skip: usize, } -#[derive(Debug, Clone)] -pub struct Config { - /// Number of tokens - pub vocab_size: u32, - - /// Maximum number of input tokens with positional embeddings - pub context_size: usize, - - /// Dimension of each token's embedding - pub embed_dim: usize, - - /// Number of attention heads - pub n_heads: usize, - - /// Dimension of each attn head - pub head_dim: usize, - - /// Number of transformer blocks - pub n_layers: usize, - - pub embed_drop_rate: f64, - pub attention_drop_rate: f64, - pub shortcut_drop_rate: f64, +pub struct ComputeConfig { + pub batch_size: usize, + pub eval_frac: f64, + pub eval_salt: String, } impl SampleDataArgs { pub fn run(self, _mp: Option) -> Result<()> { let device = CudaDevice::new(0); - - let iter = DataReader::new(1, &self.data_dir).context("while initializing data reader")?; + //let device = WgpuDevice::DiscreteGpu(0); let tokenizer = File::open(&self.tokenizer).context("while opening tokenizer")?; let tokenizer: Tokenizer = serde_json::from_reader(tokenizer).context("while loading tokenizer")?; - let config = Config { + let ccfg = ComputeConfig { + batch_size: 10, + eval_frac: 0.1, + eval_salt: "salt".into(), + }; + + let mcfg = GptModelConfig { vocab_size: tokenizer.vocab_size(), - context_size: 4, + context_size: 256, embed_dim: 768, n_heads: 12, head_dim: 64, // = 768 / 12 - n_layers: 12, + n_layers: 1, embed_drop_rate: 0.1, attention_drop_rate: 0.1, shortcut_drop_rate: 0.1, }; - let stride = config.context_size; - - let batch_size = 10; - let mut input_batch = Vec::with_capacity(batch_size); - let mut output_batch = Vec::with_capacity(batch_size); - - #[expect(clippy::unwrap_used)] // Lazy error handling - let iter = iter.map(|x| x.unwrap()).skip(self.skip).take(self.n); - - let model = GptModel::new(&config, &device); - - // Text generation routine + let mut model: GptModel> = mcfg.init(&device); /* - { - let init = "Initial context. 
This is "; - let tokens = tokenizer.encode(&init); + let loader_train = DataLoaderBuilder::new(batcher.clone()) + .batch_size(ccfg.batch_size) + //.shuffle(config.seed) + .num_workers(5) + .build(Loader::new(&self.data_dir).context("while initializing loader")?); - let n_tokens = tokens.len(); - let input: Array1 = Array1::from_vec(tokens); - let mut input: Tensor = - Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device) - .reshape([n_tokens]); + let loader_test = DataLoaderBuilder::new(batcher) + .batch_size(ccfg.batch_size) + //.shuffle(config.seed) + .num_workers(5) + .build(Loader::new(&self.data_dir).context("while initializing loader")?); - for _ in 0..100 { - let tokens: Vec = input.clone().to_data().convert::().into_vec().unwrap(); - println!("{:?}", tokens); - println!("{}", tokenizer.decode(&tokens)); + let learner = LearnerBuilder::new("./tmp") + .metric_train_numeric(AccuracyMetric::new()) + .metric_valid_numeric(AccuracyMetric::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) + .with_file_checkpointer(CompactRecorder::new()) + .learning_strategy(LearningStrategy::SingleDevice(device.clone())) + .num_epochs(10) + .summary() + .build(model, AdamConfig::new().init(), 1e-4); - // Crop idx to context size; - let batch = input - .clone() - .slice([0..config.context_size]) - .unsqueeze_dim(0); - - // shape: [tokens, vocab_size] - let logits = model.forward(batch).squeeze_dim::<2>(0); - - // shape: [vocab_size] - let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0); - - let probs = softmax(logits, 0); // shape: [n_tokens] - let id_next = probs.argmax(0); // shape: [1] - - input = Tensor::cat(vec![input.slice([1..]), id_next], 0); - } - } + learner.fit(loader_train, loader_test); */ - for i in iter { - let tokens = tokenizer.encode(&i); + // Initialize optimizer + let mut optim = AdamConfig::new().init(); + let learning_rate = 1e-4; - // Skip small texts. - // TODO: do this better - // TODO: non-uniform batches? - if tokens.len() < config.context_size { - continue; + for epoch in 0..10 { + debug!("Running epoch {epoch}"); + + // Training phase + let mut train_loss_sum = 0.0; + let mut train_total = 0; + + for batch in + TrainTestIterator::new(&self.data_dir, &ccfg, &mcfg, &tokenizer, false, &device) + .context("while initializing reader")? + { + let batch = batch.context("while reading batch")?; + + // Forward pass with gradients + let output = model.forward_train(batch.inputs, batch.targets); + + train_total += output.targets.dims()[0] as i32; + train_loss_sum += output.loss.clone().into_scalar().to_f32(); + + let grads = output.loss.backward(); + let grads = GradientsParams::from_grads(grads, &model); + + model = optim.step(learning_rate, model, grads); } - for (a, b) in tokens.windows(config.context_size).step_by(stride).zip( - tokens[stride..] - .windows(config.context_size) - .step_by(stride), - ) { - input_batch.push(a.to_owned()); - output_batch.push(b.to_owned()); + let mut valid_loss_sum = 0.0; + let mut valid_total = 0; - /* - let context = a; - let desired = &b[b.len() - 1..]; - println!("{context:?} -> {desired:?}"); + let mut n_eval = 0; + debug!("Evaluating batches"); - let input = tokenizer.decode(context); - let target = tokenizer.decode(desired); - println!("{input:?} -> {target:?}"); - */ + for batch in + TrainTestIterator::new(&self.data_dir, &ccfg, &mcfg, &tokenizer, true, &device) + .context("while initializing reader")? 
+ { + let batch = batch.context("while reading batch")?; + n_eval += batch.targets.shape()[0]; - if input_batch.len() >= batch_size { - let shape = [input_batch.len(), config.context_size]; + // Forward pass without gradients + let output = model.valid().forward_train(batch.inputs, batch.targets); - let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size)); - let input: Array2 = Array2::from_shape_fn(shape, |(a, b)| input[a][b]); - - #[expect(clippy::unwrap_used)] - let input: Tensor = - Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device) - .reshape(shape); - - let output = - std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size)); - let output: Array2 = Array2::from_shape_fn(shape, |(a, b)| output[a][b]); - - #[expect(clippy::unwrap_used)] - let output: Tensor = - Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device) - .reshape(shape); - - self.batch(&config, input, &model); - } + valid_total += output.targets.dims()[0] as i32; + valid_loss_sum += output.loss.into_scalar().to_f32(); } - } - if !input_batch.is_empty() { - let shape = [input_batch.len(), config.context_size]; + // Compute and log epoch results + let train_loss = if train_total > 0 { + train_loss_sum / train_total as f32 + } else { + 0.0 + }; + let valid_loss = if valid_total > 0 { + valid_loss_sum / valid_total as f32 + } else { + 0.0 + }; - let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size)); - let input: Array2 = Array2::from_shape_fn(shape, |(a, b)| input[a][b]); - - #[expect(clippy::unwrap_used)] - let input: Tensor = - Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device).reshape(shape); - - let output = std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size)); - let output: Array2 = Array2::from_shape_fn(shape, |(a, b)| output[a][b]); - - #[expect(clippy::unwrap_used)] - let output: Tensor = - Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device).reshape(shape); - - self.batch(&config, input, &model); + info!(message = "Ran epoch", epoch, train_loss, valid_loss, n_eval); } Ok(()) } - - fn batch(&self, _cfg: &Config, input: Tensor, model: &GptModel) { - let logits = model.forward(input); - println!("{:?}", logits.shape()); - } } +// +// MARK: model +// + /// Multihead attention. /// /// Equivalent to many stacked CausalAttention layers. 
@@ -315,7 +451,7 @@ impl MultiheadAttention { }, device.clone(), true, - [embedding_dim, total_dim].into(), + [total_dim, total_dim].into(), ), dropout: Dropout { prob: dropout }, @@ -389,6 +525,7 @@ impl MultiheadAttention { let mask = self .utri_mask .clone() + .slice([0..tokens, 0..tokens]) .unsqueeze_dim::<3>(0) .unsqueeze_dim::<4>(0) .expand(attn_scores.shape()); @@ -422,35 +559,50 @@ impl MultiheadAttention { } } -#[derive(Module, Debug)] -pub struct GptModel { - embedder_tok: Embedding, - embedder_pos: Embedding, - embedder_drop: Dropout, +#[derive(Config, Debug)] +pub struct GptModelConfig { + /// Number of tokens + pub vocab_size: u32, - trf_blocks: Vec>, - final_norm: LayerNorm, - out_head: Param>, + /// Maximum number of input tokens with positional embeddings + pub context_size: usize, + + /// Dimension of each token's embedding + pub embed_dim: usize, + + /// Number of attention heads + pub n_heads: usize, + + /// Dimension of each attn head + pub head_dim: usize, + + /// Number of transformer blocks + pub n_layers: usize, + + pub embed_drop_rate: f64, + pub attention_drop_rate: f64, + pub shortcut_drop_rate: f64, } -impl GptModel { - pub fn new(cfg: &Config, device: &B::Device) -> Self { - let out_head_shape = [cfg.embed_dim, cfg.vocab_size as usize]; +impl GptModelConfig { + pub fn init(&self, device: &B::Device) -> GptModel { + let out_head_shape = [self.embed_dim, self.vocab_size as usize]; - Self { - embedder_tok: EmbeddingConfig::new(cfg.vocab_size as usize, cfg.embed_dim).init(device), + GptModel { + embedder_tok: EmbeddingConfig::new(self.vocab_size as usize, self.embed_dim) + .init(device), - embedder_pos: EmbeddingConfig::new(cfg.context_size, cfg.embed_dim).init(device), + embedder_pos: EmbeddingConfig::new(self.context_size, self.embed_dim).init(device), embedder_drop: Dropout { - prob: cfg.embed_drop_rate, + prob: self.embed_drop_rate, }, - trf_blocks: (0..cfg.n_layers) - .map(|_| TransformerBlock::new(cfg, device)) + trf_blocks: (0..self.n_layers) + .map(|_| TransformerBlock::new(&self, device)) .collect(), - final_norm: LayerNormConfig::new(cfg.embed_dim).init(device), + final_norm: LayerNormConfig::new(self.embed_dim).init(device), out_head: Param::uninitialized( ParamId::new(), @@ -464,13 +616,34 @@ impl GptModel { ), } } +} +#[derive(Debug, Clone)] +pub struct TrainBatch { + pub inputs: Tensor, + + /// Correct next token for each input + pub targets: Tensor, +} + +#[derive(Module, Debug)] +pub struct GptModel { + embedder_tok: Embedding, + embedder_pos: Embedding, + embedder_drop: Dropout, + + trf_blocks: Vec>, + final_norm: LayerNorm, + out_head: Param>, +} + +impl GptModel { pub fn forward(&self, input: Tensor) -> Tensor { let n_tokens = input.shape()[1]; let embed_tok = self.embedder_tok.forward(input.clone()); let embed_pos = self - .embedder_tok + .embedder_pos .forward(Tensor::arange(0..n_tokens as i64, &input.device()).unsqueeze_dim(0)); let x = embed_tok + embed_pos; @@ -481,6 +654,29 @@ impl GptModel { return logits; } + + pub fn forward_train( + &self, + inputs: Tensor, + targets: Tensor, + ) -> ClassificationOutput { + // shape: [batch, n_tokens, n_vocabulary] + let output = self.forward(inputs); + + // Get last token + // shape: [batch, n_vocabulary] + let output = output.slice_dim(1, -1).squeeze_dim::<2>(1); + + let loss = CrossEntropyLossConfig::new() + .init(&targets.device()) + .forward(output.clone(), targets.clone()); + + ClassificationOutput { + loss, + output, + targets, + } + } } #[derive(Module, Debug)] @@ -498,7 +694,7 @@ pub struct 
TransformerBlock { } impl TransformerBlock { - pub fn new(cfg: &Config, device: &B::Device) -> Self { + pub fn new(cfg: &GptModelConfig, device: &B::Device) -> Self { Self { attention: MultiheadAttention::new( cfg.embed_dim, diff --git a/crates/llmfs/src/data_reader.rs b/crates/llmfs/src/data_reader.rs index 7e1188d..66dfc5f 100644 --- a/crates/llmfs/src/data_reader.rs +++ b/crates/llmfs/src/data_reader.rs @@ -25,10 +25,11 @@ pub enum DataReaderError { /// /// All parquet files have exactly one text column. /// No promises about this struct's behavior if this is not the case. +#[derive(Clone)] pub struct DataReader { - rx: Receiver>, + rx: Arc>>>, total_rows: usize, - consumed_rows: AtomicUsize, + consumed_rows: Arc, } impl DataReader { @@ -57,6 +58,15 @@ impl DataReader { files.push(path); } } + + files.sort_by_key(|a| { + a.file_name() + .map(|x| x.to_str()) + .flatten() + .unwrap_or("") + .to_owned() + }); + (Arc::new(Mutex::new(files)), total_rows) }; @@ -147,9 +157,9 @@ impl DataReader { }); Ok(Self { - rx, + rx: Arc::new(Mutex::new(rx)), total_rows, - consumed_rows: AtomicUsize::new(0), + consumed_rows: Arc::new(AtomicUsize::new(0)), }) } @@ -157,7 +167,7 @@ impl DataReader { /// Order is arbitrary. /// Returns `None` when all rows have been read. pub fn recv(&self) -> Option> { - self.rx.recv().ok() + self.rx.lock().recv().ok() } //pub fn try_recv(&self) -> Result, TryRecvError> { diff --git a/crates/llmfs/src/logging.rs b/crates/llmfs/src/logging.rs index 7660b9e..9b07eaf 100644 --- a/crates/llmfs/src/logging.rs +++ b/crates/llmfs/src/logging.rs @@ -12,18 +12,16 @@ use tracing_subscriber::{ // MARK: loglevel // -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, ValueEnum)] -#[derive(Default)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, ValueEnum, Default)] pub enum LogLevel { Trace, Debug, #[default] - Info, + Info, Warn, Error, } - impl Display for LogLevel { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -71,7 +69,7 @@ impl From for EnvFilter { // // Bins // - format!("nanochat_rs={}", conf.nanochat), + format!("llmfs={}", conf.nanochat), conf.other.to_string(), ] .join(","), @@ -216,16 +214,14 @@ pub enum LoggingTarget { } /// How to print logs -#[derive(Debug, Clone, Copy, Deserialize)] -#[derive(Default)] +#[derive(Debug, Clone, Copy, Deserialize, Default)] pub enum LoggingFormat { #[default] - Ansi, + Ansi, AnsiNoColor, Json, } - pub struct LoggingInitializer { /// Log filter for printed logs pub preset: LogFilterPreset, diff --git a/crates/llmfs/src/main.rs b/crates/llmfs/src/main.rs index 2297c77..e92596e 100644 --- a/crates/llmfs/src/main.rs +++ b/crates/llmfs/src/main.rs @@ -1,3 +1,5 @@ +#![recursion_limit = "256"] + use clap::Parser; use indicatif::MultiProgress; use tracing::error;
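
For reference, a minimal, self-contained sketch of the deterministic train/eval split idea used by `TrainTestIterator::next` in `sample_data.rs` above: hash a salt together with each (input, target) example and threshold the hash so a fixed fraction of examples lands in the eval split. The helper name `is_eval_example`, the `u32` token type, and per-token `write_u32` hashing are illustrative assumptions, not identifiers from this change; the actual iterator hashes the raw slice bytes and orients the comparison differently.

use std::hash::Hasher;
use ahash::AHasher;

/// Hypothetical helper: decide whether an (input, target) pair belongs to the
/// eval split. Deterministic per (salt, input, target), so no RNG state needs
/// to be carried across epochs or workers.
fn is_eval_example(salt: &str, input: &[u32], target: u32, eval_frac: f64) -> bool {
    let mut hasher = AHasher::default();
    hasher.write(salt.as_bytes());
    for tok in input {
        hasher.write_u32(*tok);
    }
    hasher.write_u32(target);
    // Roughly `eval_frac` of all hash values fall below this threshold.
    (hasher.finish() as f64) < u64::MAX as f64 * eval_frac
}

fn main() {
    let in_eval = is_eval_example("salt", &[1, 2, 3, 4], 5, 0.1);
    println!("eval split: {in_eval}");
}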