TMP

Cargo.lock (generated)
@@ -451,6 +451,7 @@ dependencies = [
  "ahash",
  "bincode",
  "burn-common",
+ "burn-dataset",
  "burn-derive",
  "burn-tensor",
  "data-encoding",
@@ -546,6 +547,25 @@ dependencies = [
  "log",
 ]
 
+[[package]]
+name = "burn-dataset"
+version = "0.19.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "534d4398fd6aaec32f8caeb3f20ddffcd8a059bdefc01cc2794b91b4e984e8ea"
+dependencies = [
+ "csv",
+ "derive-new",
+ "dirs",
+ "rand",
+ "rmp-serde",
+ "sanitize-filename",
+ "serde",
+ "serde_json",
+ "strum",
+ "tempfile",
+ "thiserror 2.0.17",
+]
+
 [[package]]
 name = "burn-derive"
 version = "0.19.1"
@@ -598,8 +618,10 @@ dependencies = [
  "burn-common",
  "burn-ir",
  "burn-tensor",
+ "bytemuck",
  "const-random",
  "derive-new",
+ "itertools 0.14.0",
  "libm",
  "macerator",
  "matrixmultiply",
@@ -608,6 +630,7 @@ dependencies = [
  "paste",
  "portable-atomic-util",
  "rand",
+ "seq-macro",
  "spin",
 ]
 
@@ -703,6 +726,25 @@ dependencies = [
  "serde_bytes",
 ]
 
+[[package]]
+name = "burn-train"
+version = "0.19.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b0f1553197d50668823a4bafc187c62439df49b218973f0ca79e034b57ce38d6"
+dependencies = [
+ "async-channel",
+ "burn-core",
+ "burn-ndarray",
+ "burn-optim",
+ "derive-new",
+ "log",
+ "rstest",
+ "serde",
+ "tracing-appender",
+ "tracing-core",
+ "tracing-subscriber",
+]
+
 [[package]]
 name = "burn-wgpu"
 version = "0.19.1"
@@ -1057,6 +1099,15 @@ version = "1.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b"
 
+[[package]]
+name = "crossbeam-channel"
+version = "0.5.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2"
+dependencies = [
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "crossbeam-deque"
 version = "0.8.6"
@@ -1098,6 +1149,27 @@ dependencies = [
  "typenum",
 ]
 
+[[package]]
+name = "csv"
+version = "1.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938"
+dependencies = [
+ "csv-core",
+ "itoa",
+ "ryu",
+ "serde_core",
+]
+
+[[package]]
+name = "csv-core"
+version = "0.1.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782"
+dependencies = [
+ "memchr",
+]
+
 [[package]]
 name = "cubecl"
 version = "0.8.1"
@@ -1573,6 +1645,15 @@ version = "2.9.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
 
+[[package]]
+name = "deranged"
+version = "0.5.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ececcb659e7ba858fb4f10388c250a7252eb0a27373f1a72b8748afdd248e587"
+dependencies = [
+ "powerfmt",
+]
+
 [[package]]
 name = "derive-new"
 version = "0.7.0"
@@ -2064,6 +2145,12 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
 
+[[package]]
+name = "futures-timer"
+version = "3.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
+
 [[package]]
 name = "futures-util"
 version = "0.3.31"
@@ -3044,15 +3131,18 @@ checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092"
 name = "llmfs"
 version = "0.0.1"
 dependencies = [
+ "ahash",
  "anstyle",
  "anyhow",
  "burn",
+ "burn-train",
  "clap",
  "futures-util",
  "indicatif",
  "ndarray",
  "parking_lot",
  "parquet",
+ "rand",
  "rayon",
  "reqwest",
  "serde",
@@ -3149,7 +3239,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
 dependencies = [
  "autocfg",
+ "num_cpus",
+ "once_cell",
  "rawpointer",
+ "thread-tree",
 ]
 
 [[package]]
@@ -3299,6 +3392,7 @@ dependencies = [
  "portable-atomic",
  "portable-atomic-util",
  "rawpointer",
+ "rayon",
  "serde",
 ]
 
@@ -3379,6 +3473,12 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "num-conv"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
+
 [[package]]
 name = "num-integer"
 version = "0.1.46"
@@ -3689,6 +3789,12 @@ dependencies = [
  "zerovec",
 ]
 
+[[package]]
+name = "powerfmt"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
+
 [[package]]
 name = "ppv-lite86"
 version = "0.2.21"
@@ -3994,6 +4100,12 @@ version = "0.8.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58"
 
+[[package]]
+name = "relative-path"
+version = "1.9.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2"
+
 [[package]]
 name = "renderdoc-sys"
 version = "1.1.0"
@@ -4084,6 +4196,35 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "rstest"
+version = "0.26.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f5a3193c063baaa2a95a33f03035c8a72b83d97a54916055ba22d35ed3839d49"
+dependencies = [
+ "futures-timer",
+ "futures-util",
+ "rstest_macros",
+]
+
+[[package]]
+name = "rstest_macros"
+version = "0.26.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9c845311f0ff7951c5506121a9ad75aec44d083c31583b2ea5a30bcb0b0abba0"
+dependencies = [
+ "cfg-if",
+ "glob",
+ "proc-macro-crate",
+ "proc-macro2",
+ "quote",
+ "regex",
+ "relative-path",
+ "rustc_version",
+ "syn",
+ "unicode-ident",
+]
+
 [[package]]
 name = "rustc-hash"
 version = "1.1.0"
@@ -4672,6 +4813,15 @@ dependencies = [
  "syn",
 ]
 
+[[package]]
+name = "thread-tree"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ffbd370cb847953a25954d9f63e14824a36113f8c72eecf6eccef5dc4b45d630"
+dependencies = [
+ "crossbeam-channel",
+]
+
 [[package]]
 name = "thread_local"
 version = "1.1.9"
@@ -4692,6 +4842,37 @@ dependencies = [
  "ordered-float 2.10.1",
 ]
 
+[[package]]
+name = "time"
+version = "0.3.44"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d"
+dependencies = [
+ "deranged",
+ "itoa",
+ "num-conv",
+ "powerfmt",
+ "serde",
+ "time-core",
+ "time-macros",
+]
+
+[[package]]
+name = "time-core"
+version = "0.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b"
+
+[[package]]
+name = "time-macros"
+version = "0.2.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3"
+dependencies = [
+ "num-conv",
+ "time-core",
+]
+
 [[package]]
 name = "tiny-keccak"
 version = "2.0.2"
@@ -4989,6 +5170,18 @@ dependencies = [
  "tracing-core",
 ]
 
+[[package]]
+name = "tracing-appender"
+version = "0.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "786d480bce6247ab75f005b14ae1624ad978d3029d9113f0a22fa1ac773faeaf"
+dependencies = [
+ "crossbeam-channel",
+ "thiserror 2.0.17",
+ "time",
+ "tracing-subscriber",
+]
+
 [[package]]
 name = "tracing-attributes"
 version = "0.1.31"
@@ -75,10 +75,12 @@ compact_str = "0.9.0"
 dary_heap = "0.3.8"
 fancy-regex = "0.16.2"
 indicatif = { version = "0.18.3", features = ["improved_unicode"] }
+itertools = "0.14.0"
 futures-util = "0.3.31"
 ndarray = { version = "0.16.1", features = ["serde"] }
 parking_lot = "0.12.5"
 parquet = "56.2.0"
+rand = "0.9.2"
 rayon = "1.11.0"
 reqwest = { version = "0.12.24", features = ["json", "stream"] }
 serde = "1.0.228"
@@ -91,7 +93,10 @@ tracing-indicatif = "0.3.13"
 tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] }
 url = "2.5.7"
 
+
+burn-train = { version = "0.19.1", default-features = false }
+
 [workspace.dependencies.burn]
 version = "0.19.1"
 default-features = false
-features = ["std", "fusion", "ndarray", "webgpu", "cuda"]
+features = ["std", "fusion", "ndarray", "webgpu", "cuda", "autodiff"]
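Note: the newly enabled `autodiff` feature and the `burn-train` dependency are what the hand-written training loop later in this commit relies on. A minimal sketch of what the feature provides, assuming burn 0.19's public API (the function and names here are illustrative, not from the commit):

```rust
use burn::backend::{Autodiff, NdArray};
use burn::backend::ndarray::NdArrayDevice;
use burn::tensor::Tensor;

// Wrapping any backend in `Autodiff<…>` records operations so `backward()` works.
fn sketch(device: &NdArrayDevice) {
    let x: Tensor<Autodiff<NdArray>, 1> = Tensor::from_floats([1.0, 2.0, 3.0], device);
    let y = x.clone().powf_scalar(2.0).sum();
    let grads = y.backward(); // only available on autodiff backends
    let dx = x.grad(&grads);  // d(sum(x^2))/dx = 2x, returned on the inner backend
    println!("{dx:?}");
}
```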
@@ -10,15 +10,18 @@ workspace = true
 [dependencies]
 tokenizer = { workspace = true }
 
+ahash = { workspace = true }
 anstyle = { workspace = true }
 anyhow = { workspace = true }
 burn = { workspace = true }
+burn-train = { workspace = true }
 clap = { workspace = true }
 futures-util = { workspace = true }
 indicatif = { workspace = true }
 ndarray = { workspace = true }
 parking_lot = { workspace = true }
 parquet = { workspace = true }
+rand = { workspace = true }
 rayon = { workspace = true }
 reqwest = { workspace = true }
 serde = { workspace = true }
@@ -15,7 +15,6 @@ pub enum SubCommand {
         #[command(flatten)]
         args: train_tokenizer::TrainTokenizerArgs,
     },
-
     /// Sample data
     SampleData {
         #[command(flatten)]
@@ -1,22 +1,206 @@
+use ahash::AHasher;
 use anyhow::{Context, Result};
 use burn::{
     Tensor,
-    backend::{Cuda, cuda::CudaDevice},
-    module::{Module, Param, ParamId},
+    backend::{Autodiff, Cuda, Wgpu, cuda::CudaDevice, wgpu::WgpuDevice},
+    config::Config,
+    module::{AutodiffModule, Module, Param, ParamId},
     nn::{
         Dropout, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
+        loss::CrossEntropyLossConfig,
         transformer::{PositionWiseFeedForward, PositionWiseFeedForwardConfig},
     },
-    prelude::Backend,
+    optim::{AdamConfig, GradientsParams, Optimizer},
+    prelude::{Backend, ToElement},
     tensor::{Bool, Distribution, Int, activation::softmax},
 };
+use burn_train::ClassificationOutput;
 use clap::Args;
-use indicatif::MultiProgress;
-use ndarray::Array2;
-use std::{f32, fs::File, path::PathBuf};
+use indicatif::{MultiProgress, ProgressIterator};
+use ndarray::{Array1, Array2};
+use std::{
+    collections::VecDeque,
+    f32,
+    fs::File,
+    hash::Hasher,
+    path::{Path, PathBuf},
+};
 use tokenizer::Tokenizer;
+use tracing::{debug, info};
 
-use crate::data_reader::DataReader;
+use crate::data_reader::{DataReader, DataReaderError};
 
+// Text generation routine
+
+/*
+{
+    let init = "Initial context. This is ";
+    let tokens = tokenizer.encode(&init);
+
+    let n_tokens = tokens.len();
+    let input: Array1<u32> = Array1::from_vec(tokens);
+    let mut input: Tensor<Cuda, 1, Int> =
+        Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
+            .reshape([n_tokens]);
+
+    for _ in 0..100 {
+        let tokens: Vec<u32> = input.clone().to_data().convert::<u32>().into_vec().unwrap();
+        println!("{:?}", tokens);
+        println!("{}", tokenizer.decode(&tokens));
+
+        // Crop idx to context size;
+        let batch = input
+            .clone()
+            .slice([0..config.context_size])
+            .unsqueeze_dim(0);
+
+        // shape: [tokens, vocab_size]
+        let logits = model.forward(batch).squeeze_dim::<2>(0);
+
+        // shape: [vocab_size]
+        let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0);
+
+        let probs = softmax(logits, 0); // shape: [n_tokens]
+        let id_next = probs.argmax(0); // shape: [1]
+
+        input = Tensor::cat(vec![input.slice([1..]), id_next], 0);
+    }
+}
+*/
+
+struct TrainTestIterator<'a, B: Backend> {
+    reader: DataReader,
+
+    ccfg: &'a ComputeConfig,
+    mcfg: &'a GptModelConfig,
+    tokenizer: &'a Tokenizer,
+    eval: bool,
+    device: &'a B::Device,
+
+    error: bool,
+
+    // Tokenized input/output pairs
+    pairs: VecDeque<(Vec<u32>, u32)>,
+}
+
+impl<'a, B: Backend> TrainTestIterator<'a, B> {
+    pub fn new(
+        data_dir: impl AsRef<Path>,
+        ccfg: &'a ComputeConfig,
+        mcfg: &'a GptModelConfig,
+        tokenizer: &'a Tokenizer,
+        eval: bool,
+        device: &'a B::Device,
+    ) -> Result<Self, std::io::Error> {
+        let reader = DataReader::new(3, data_dir)?;
+
+        Ok(Self {
+            reader,
+            ccfg,
+            mcfg,
+            tokenizer,
+            eval,
+            device,
+
+            error: false,
+            pairs: VecDeque::new(),
+        })
+    }
+}
+
+impl<B: Backend> Iterator for TrainTestIterator<'_, B> {
+    type Item = Result<TrainBatch<B>, DataReaderError>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.error {
+            return None;
+        }
+
+        let mut inputs = Vec::with_capacity(self.ccfg.batch_size);
+        let mut targets = Vec::with_capacity(self.ccfg.batch_size);
+        let stride = self.mcfg.context_size;
+
+        while inputs.len() < self.ccfg.batch_size {
+            match self.pairs.pop_front() {
+                Some((i, t)) => {
+                    // train/test split
+                    {
+                        let mut hasher = AHasher::default();
+                        hasher.write(self.ccfg.eval_salt.as_bytes());
+
+                        // Don't care about endianness, ahash output is unstable anyway
+                        hasher.write(unsafe { std::mem::transmute(&i[..]) });
+                        hasher.write_u32(t);
+
+                        let test = // is this point in the test set?
+                            hasher.finish() > (u64::MAX as f64 * self.ccfg.eval_frac).to_u64();
+
+                        if test ^ self.eval {
+                            continue;
+                        }
+                    }
+
+                    inputs.push(i);
+                    targets.push(t);
+                }
+
+                None => {
+                    let text = match self.reader.next() {
+                        None => break,
+                        Some(Ok(x)) => x,
+                        Some(Err(x)) => {
+                            self.error = true;
+                            return Some(Err(x));
+                        }
+                    };
+
+                    let emb = self.tokenizer.encode(&text);
+
+                    // Skip small texts
+                    //
+                    // TODO: do this better
+                    // TODO: maybe using <|bos|>?
+                    // TODO: non-uniform batches?
+                    if emb.len() < self.mcfg.context_size {
+                        continue;
+                    }
+
+                    let pairs = emb
+                        .windows(self.mcfg.context_size + 1)
+                        .step_by(stride)
+                        .map(|x| {
+                            (
+                                x[..self.mcfg.context_size].to_vec(),
+                                x[self.mcfg.context_size],
+                            )
+                        });
+
+                    self.pairs.extend(pairs);
+                }
+            }
+        }
+
+        if inputs.is_empty() {
+            return None;
+        }
+
+        let shape = [inputs.len(), self.mcfg.context_size];
+
+        // Arrange data in memory
+        let inputs: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| inputs[a][b]);
+        let targets: Array1<u32> = Array1::from_vec(targets);
+
+        // Create tensors on gpu
+        #[expect(clippy::unwrap_used)]
+        let inputs =
+            Tensor::<B, 1, Int>::from_ints(inputs.as_slice().unwrap(), self.device).reshape(shape);
+
+        #[expect(clippy::unwrap_used)]
+        let targets = Tensor::<B, 1, Int>::from_ints(targets.as_slice().unwrap(), self.device);
+
+        return Some(Ok(TrainBatch { inputs, targets }));
+    }
+}
+
 #[derive(Debug, Args, Clone)]
 
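A note on the split logic in the iterator above: train or eval membership is decided per sample by a salted hash, so the split is deterministic across runs without storing any index on disk. The same idea in isolation (the `is_eval` helper is hypothetical; the commit inlines this in `next()` with a slightly different comparison direction):

```rust
use ahash::AHasher;
use std::hash::Hasher;

/// Deterministically place a sample in the eval set with probability `eval_frac`.
/// Same salt + same bytes => same answer on every run; no split file needed.
fn is_eval(salt: &str, sample: &[u8], eval_frac: f64) -> bool {
    let mut hasher = AHasher::default();
    hasher.write(salt.as_bytes());
    hasher.write(sample);
    // A uniform hash lands below `eval_frac * u64::MAX` with probability eval_frac.
    (hasher.finish() as f64) < u64::MAX as f64 * eval_frac
}

fn main() {
    let n_eval = (0u32..10_000)
        .filter(|i| is_eval("salt", &i.to_le_bytes(), 0.1))
        .count();
    println!("{n_eval} of 10000 samples in eval (~1000 expected)");
}
```

One caveat the commit's own comment also raises: ahash's output is not stable across versions or platforms, so the split is only reproducible with a pinned ahash.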
@@ -32,190 +216,144 @@ pub struct SampleDataArgs {
     /// How many texts to return
     #[clap(long, short = 'n', default_value = "10")]
     n: usize,
-
-    /// How many texts to skip
-    #[clap(long, short = 's', default_value = "0")]
-    skip: usize,
 }
 
-#[derive(Debug, Clone)]
-pub struct Config {
-    /// Number of tokens
-    pub vocab_size: u32,
-
-    /// Maximum number of input tokens with positional embeddings
-    pub context_size: usize,
-
-    /// Dimension of each token's embedding
-    pub embed_dim: usize,
-
-    /// Number of attention heads
-    pub n_heads: usize,
-
-    /// Dimension of each attn head
-    pub head_dim: usize,
-
-    /// Number of transformer blocks
-    pub n_layers: usize,
-
-    pub embed_drop_rate: f64,
-    pub attention_drop_rate: f64,
-    pub shortcut_drop_rate: f64,
+pub struct ComputeConfig {
+    pub batch_size: usize,
+    pub eval_frac: f64,
+    pub eval_salt: String,
 }
 
 impl SampleDataArgs {
     pub fn run(self, _mp: Option<MultiProgress>) -> Result<()> {
         let device = CudaDevice::new(0);
+        //let device = WgpuDevice::DiscreteGpu(0);
-        let iter = DataReader::new(1, &self.data_dir).context("while initializing data reader")?;
 
         let tokenizer = File::open(&self.tokenizer).context("while opening tokenizer")?;
         let tokenizer: Tokenizer =
             serde_json::from_reader(tokenizer).context("while loading tokenizer")?;
 
-        let config = Config {
+        let ccfg = ComputeConfig {
+            batch_size: 10,
+            eval_frac: 0.1,
+            eval_salt: "salt".into(),
+        };
+
+        let mcfg = GptModelConfig {
             vocab_size: tokenizer.vocab_size(),
-            context_size: 4,
+            context_size: 256,
             embed_dim: 768,
             n_heads: 12,
             head_dim: 64, // = 768 / 12
-            n_layers: 12,
+            n_layers: 1,
             embed_drop_rate: 0.1,
             attention_drop_rate: 0.1,
             shortcut_drop_rate: 0.1,
         };
 
-        let stride = config.context_size;
-
-        let batch_size = 10;
-        let mut input_batch = Vec::with_capacity(batch_size);
-        let mut output_batch = Vec::with_capacity(batch_size);
-
-        #[expect(clippy::unwrap_used)] // Lazy error handling
-        let iter = iter.map(|x| x.unwrap()).skip(self.skip).take(self.n);
-
-        let model = GptModel::new(&config, &device);
-
-        // Text generation routine
+        let mut model: GptModel<Autodiff<Cuda>> = mcfg.init(&device);
 
         /*
-        {
-            let init = "Initial context. This is ";
-            let tokens = tokenizer.encode(&init);
-
-            let n_tokens = tokens.len();
-            let input: Array1<u32> = Array1::from_vec(tokens);
-            let mut input: Tensor<Cuda, 1, Int> =
-                Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
-                    .reshape([n_tokens]);
-
-            for _ in 0..100 {
-                let tokens: Vec<u32> = input.clone().to_data().convert::<u32>().into_vec().unwrap();
-                println!("{:?}", tokens);
-                println!("{}", tokenizer.decode(&tokens));
-
-                // Crop idx to context size;
-                let batch = input
-                    .clone()
-                    .slice([0..config.context_size])
-                    .unsqueeze_dim(0);
-
-                // shape: [tokens, vocab_size]
-                let logits = model.forward(batch).squeeze_dim::<2>(0);
-
-                // shape: [vocab_size]
-                let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0);
-
-                let probs = softmax(logits, 0); // shape: [n_tokens]
-                let id_next = probs.argmax(0); // shape: [1]
-
-                input = Tensor::cat(vec![input.slice([1..]), id_next], 0);
-            }
-        }
+        let loader_train = DataLoaderBuilder::new(batcher.clone())
+            .batch_size(ccfg.batch_size)
+            //.shuffle(config.seed)
+            .num_workers(5)
+            .build(Loader::new(&self.data_dir).context("while initializing loader")?);
+
+        let loader_test = DataLoaderBuilder::new(batcher)
+            .batch_size(ccfg.batch_size)
+            //.shuffle(config.seed)
+            .num_workers(5)
+            .build(Loader::new(&self.data_dir).context("while initializing loader")?);
+
+        let learner = LearnerBuilder::new("./tmp")
+            .metric_train_numeric(AccuracyMetric::new())
+            .metric_valid_numeric(AccuracyMetric::new())
+            .metric_train_numeric(LossMetric::new())
+            .metric_valid_numeric(LossMetric::new())
+            .with_file_checkpointer(CompactRecorder::new())
+            .learning_strategy(LearningStrategy::SingleDevice(device.clone()))
+            .num_epochs(10)
+            .summary()
+            .build(model, AdamConfig::new().init(), 1e-4);
+
+        learner.fit(loader_train, loader_test);
         */
 
-        for i in iter {
-            let tokens = tokenizer.encode(&i);
-
-            // Skip small texts.
-            // TODO: do this better
-            // TODO: non-uniform batches?
-            if tokens.len() < config.context_size {
-                continue;
-            }
+        // Initialize optimizer
+        let mut optim = AdamConfig::new().init();
+        let learning_rate = 1e-4;
+
+        for epoch in 0..10 {
+            debug!("Running epoch {epoch}");
+
+            // Training phase
+            let mut train_loss_sum = 0.0;
+            let mut train_total = 0;
+
+            for batch in
+                TrainTestIterator::new(&self.data_dir, &ccfg, &mcfg, &tokenizer, false, &device)
+                    .context("while initializing reader")?
+            {
+                let batch = batch.context("while reading batch")?;
+
+                // Forward pass with gradients
+                let output = model.forward_train(batch.inputs, batch.targets);
+
+                train_total += output.targets.dims()[0] as i32;
+                train_loss_sum += output.loss.clone().into_scalar().to_f32();
+
+                debug!("Running backward pass");
+                let grads = output.loss.backward();
+                let grads = GradientsParams::from_grads(grads, &model);
+
+                debug!("Running optimizer step");
+                model = optim.step(learning_rate, model, grads);
+            }
 
-            for (a, b) in tokens.windows(config.context_size).step_by(stride).zip(
-                tokens[stride..]
-                    .windows(config.context_size)
-                    .step_by(stride),
-            ) {
-                input_batch.push(a.to_owned());
-                output_batch.push(b.to_owned());
-
-                /*
-                let context = a;
-                let desired = &b[b.len() - 1..];
-                println!("{context:?} -> {desired:?}");
-
-                let input = tokenizer.decode(context);
-                let target = tokenizer.decode(desired);
-                println!("{input:?} -> {target:?}");
-                */
-
-                if input_batch.len() >= batch_size {
-                    let shape = [input_batch.len(), config.context_size];
-
-                    let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
-                    let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);
-
-                    #[expect(clippy::unwrap_used)]
-                    let input: Tensor<Cuda, 2, Int> =
-                        Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
-                            .reshape(shape);
-
-                    let output =
-                        std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
-                    let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);
-
-                    #[expect(clippy::unwrap_used)]
-                    let output: Tensor<Cuda, 2, Int> =
-                        Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device)
-                            .reshape(shape);
-
-                    self.batch(&config, input, &model);
-                }
-            }
-        }
+            let mut valid_loss_sum = 0.0;
+            let mut valid_total = 0;
+
+            let mut n_eval = 0;
+            debug!("Evaluating batches");
+
+            for batch in
+                TrainTestIterator::new(&self.data_dir, &ccfg, &mcfg, &tokenizer, true, &device)
+                    .context("while initializing reader")?
+            {
+                let batch = batch.context("while reading batch")?;
+                n_eval += batch.targets.shape()[0];
+
+                // Forward pass without gradients
+                let output = model.valid().forward_train(batch.inputs, batch.targets);
+
+                valid_total += output.targets.dims()[0] as i32;
+                valid_loss_sum += output.loss.into_scalar().to_f32();
+            }
 
-        if !input_batch.is_empty() {
-            let shape = [input_batch.len(), config.context_size];
-
-            let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
-            let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);
-
-            #[expect(clippy::unwrap_used)]
-            let input: Tensor<Cuda, 2, Int> =
-                Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device).reshape(shape);
-
-            let output = std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
-            let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);
-
-            #[expect(clippy::unwrap_used)]
-            let output: Tensor<Cuda, 2, Int> =
-                Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device).reshape(shape);
-
-            self.batch(&config, input, &model);
-        }
+            // Compute and log epoch results
+            let train_loss = if train_total > 0 {
+                train_loss_sum / train_total as f32
+            } else {
+                0.0
+            };
+            let valid_loss = if valid_total > 0 {
+                valid_loss_sum / valid_total as f32
+            } else {
+                0.0
+            };
+
+            info!(message = "Ran epoch", epoch, train_loss, valid_loss, n_eval);
+        }
 
         Ok(())
     }
-
-    fn batch(&self, _cfg: &Config, input: Tensor<Cuda, 2, Int>, model: &GptModel<Cuda>) {
-        let logits = model.forward(input);
-        println!("{:?}", logits.shape());
-    }
 }
+
+//
+// MARK: model
+//
 
 /// Multihead attention.
 ///
 /// Equivalent to many stacked CausalAttention layers.
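The rewritten `run` above replaces the commented-out `Learner` scaffolding with burn's manual optimization pattern: forward pass, `backward()`, bind the raw gradients to parameters with `GradientsParams`, then `Optimizer::step`, which consumes and returns the model. Reduced to a generic skeleton (a sketch against the burn 0.19 API, not code from the commit):

```rust
use burn::module::AutodiffModule;
use burn::optim::{GradientsParams, Optimizer};
use burn::tensor::Tensor;
use burn::tensor::backend::AutodiffBackend;

/// One manual optimizer step: backprop, bind gradients to parameters, update.
fn train_step<B, M, O>(model: M, optim: &mut O, lr: f64, loss: Tensor<B, 1>) -> M
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: Optimizer<M, B>,
{
    let grads = loss.backward();                            // B::Gradients
    let grads = GradientsParams::from_grads(grads, &model); // per-parameter grads
    optim.step(lr, model, grads)                            // model moves in, updated model out
}
```

Because `step` takes the model by value, the loop keeps reassigning `model = optim.step(…)`, which is why `model` is declared `mut`.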
@@ -315,7 +453,7 @@ impl<B: Backend> MultiheadAttention<B> {
                 },
                 device.clone(),
                 true,
-                [embedding_dim, total_dim].into(),
+                [total_dim, total_dim].into(),
             ),
 
             dropout: Dropout { prob: dropout },
@@ -389,6 +527,7 @@ impl<B: Backend> MultiheadAttention<B> {
         let mask = self
             .utri_mask
             .clone()
+            .slice([0..tokens, 0..tokens])
             .unsqueeze_dim::<3>(0)
             .unsqueeze_dim::<4>(0)
             .expand(attn_scores.shape());
@@ -422,35 +561,50 @@ impl<B: Backend> MultiheadAttention<B> {
     }
 }
 
-#[derive(Module, Debug)]
-pub struct GptModel<B: Backend> {
-    embedder_tok: Embedding<B>,
-    embedder_pos: Embedding<B>,
-    embedder_drop: Dropout,
-
-    trf_blocks: Vec<TransformerBlock<B>>,
-    final_norm: LayerNorm<B>,
-    out_head: Param<Tensor<B, 2>>,
+#[derive(Config, Debug)]
+pub struct GptModelConfig {
+    /// Number of tokens
+    pub vocab_size: u32,
+
+    /// Maximum number of input tokens with positional embeddings
+    pub context_size: usize,
+
+    /// Dimension of each token's embedding
+    pub embed_dim: usize,
+
+    /// Number of attention heads
+    pub n_heads: usize,
+
+    /// Dimension of each attn head
+    pub head_dim: usize,
+
+    /// Number of transformer blocks
+    pub n_layers: usize,
+
+    pub embed_drop_rate: f64,
+    pub attention_drop_rate: f64,
+    pub shortcut_drop_rate: f64,
 }
 
-impl<B: Backend> GptModel<B> {
-    pub fn new(cfg: &Config, device: &B::Device) -> Self {
-        let out_head_shape = [cfg.embed_dim, cfg.vocab_size as usize];
+impl GptModelConfig {
+    pub fn init<B: Backend>(&self, device: &B::Device) -> GptModel<B> {
+        let out_head_shape = [self.embed_dim, self.vocab_size as usize];
 
-        Self {
-            embedder_tok: EmbeddingConfig::new(cfg.vocab_size as usize, cfg.embed_dim).init(device),
+        GptModel {
+            embedder_tok: EmbeddingConfig::new(self.vocab_size as usize, self.embed_dim)
+                .init(device),
 
-            embedder_pos: EmbeddingConfig::new(cfg.context_size, cfg.embed_dim).init(device),
+            embedder_pos: EmbeddingConfig::new(self.context_size, self.embed_dim).init(device),
 
             embedder_drop: Dropout {
-                prob: cfg.embed_drop_rate,
+                prob: self.embed_drop_rate,
             },
 
-            trf_blocks: (0..cfg.n_layers)
-                .map(|_| TransformerBlock::new(cfg, device))
+            trf_blocks: (0..self.n_layers)
+                .map(|_| TransformerBlock::new(&self, device))
                 .collect(),
 
-            final_norm: LayerNormConfig::new(cfg.embed_dim).init(device),
+            final_norm: LayerNormConfig::new(self.embed_dim).init(device),
 
             out_head: Param::uninitialized(
                 ParamId::new(),
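The `Config`/`Module` split introduced in this hunk follows burn's usual convention: hyperparameters live in a plain `#[derive(Config)]` struct, and an `init` method builds the device-resident module from it. The shape of the pattern on a throwaway example (a sketch, not the commit's types):

```rust
use burn::config::Config;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::prelude::Backend;

#[derive(Config, Debug)]
pub struct TinyConfig {
    pub d_in: usize,
    pub d_out: usize,
}

#[derive(Module, Debug)]
pub struct Tiny<B: Backend> {
    linear: Linear<B>,
}

impl TinyConfig {
    /// The config stays serializable and device-free; only `init` touches a device.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Tiny<B> {
        Tiny {
            linear: LinearConfig::new(self.d_in, self.d_out).init(device),
        }
    }
}
```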
@@ -464,13 +618,34 @@ impl<B: Backend> GptModel<B> {
             ),
         }
     }
+}
+
+#[derive(Debug, Clone)]
+pub struct TrainBatch<B: Backend> {
+    pub inputs: Tensor<B, 2, Int>,
+
+    /// Correct next token for each input
+    pub targets: Tensor<B, 1, Int>,
+}
+
+#[derive(Module, Debug)]
+pub struct GptModel<B: Backend> {
+    embedder_tok: Embedding<B>,
+    embedder_pos: Embedding<B>,
+    embedder_drop: Dropout,
+
+    trf_blocks: Vec<TransformerBlock<B>>,
+    final_norm: LayerNorm<B>,
+    out_head: Param<Tensor<B, 2>>,
+}
+
+impl<B: Backend> GptModel<B> {
     pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
         let n_tokens = input.shape()[1];
 
         let embed_tok = self.embedder_tok.forward(input.clone());
         let embed_pos = self
-            .embedder_tok
+            .embedder_pos
             .forward(Tensor::arange(0..n_tokens as i64, &input.device()).unsqueeze_dim(0));
 
         let x = embed_tok + embed_pos;
@@ -481,6 +656,29 @@ impl<B: Backend> GptModel<B> {
 
         return logits;
     }
+
+    pub fn forward_train(
+        &self,
+        inputs: Tensor<B, 2, Int>,
+        targets: Tensor<B, 1, Int>,
+    ) -> ClassificationOutput<B> {
+        // shape: [batch, n_tokens, n_vocabulary]
+        let output = self.forward(inputs);
+
+        // Get last token
+        // shape: [batch, n_vocabulary]
+        let output = output.slice_dim(1, -1).squeeze_dim::<2>(1);
+
+        let loss = CrossEntropyLossConfig::new()
+            .init(&targets.device())
+            .forward(output.clone(), targets.clone());
+
+        ClassificationOutput {
+            loss,
+            output,
+            targets,
+        }
+    }
 }
 
 #[derive(Module, Debug)]
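`forward_train` supervises only the final position: the iterator emits (context window, next token) pairs, so logits of shape `[batch, n_tokens, vocab]` are sliced down to `[batch, vocab]` before the cross-entropy. The same loss wiring in isolation, on the CPU backend with tiny made-up shapes (a sketch assuming burn 0.19's `CrossEntropyLoss` API):

```rust
use burn::backend::NdArray;
use burn::backend::ndarray::NdArrayDevice;
use burn::nn::loss::CrossEntropyLossConfig;
use burn::tensor::{Int, Tensor};

fn main() {
    let device = NdArrayDevice::default();
    // batch = 2 sequences, vocab = 4: one row of logits per sequence.
    let logits: Tensor<NdArray, 2> =
        Tensor::from_floats([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]], &device);
    // One correct next-token id per sequence.
    let targets: Tensor<NdArray, 1, Int> = Tensor::from_ints([3, 0], &device);
    let loss = CrossEntropyLossConfig::new()
        .init(&device)
        .forward(logits, targets);
    println!("{}", loss.into_scalar());
}
```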
@@ -498,7 +696,7 @@ pub struct TransformerBlock<B: Backend> {
 }
 
 impl<B: Backend> TransformerBlock<B> {
-    pub fn new(cfg: &Config, device: &B::Device) -> Self {
+    pub fn new(cfg: &GptModelConfig, device: &B::Device) -> Self {
         Self {
             attention: MultiheadAttention::new(
                 cfg.embed_dim,
@@ -25,10 +25,11 @@ pub enum DataReaderError {
 ///
 /// All parquet files have exactly one text column.
 /// No promises about this struct's behavior if this is not the case.
+#[derive(Clone)]
 pub struct DataReader {
-    rx: Receiver<Result<String, DataReaderError>>,
+    rx: Arc<Mutex<Receiver<Result<String, DataReaderError>>>>,
     total_rows: usize,
-    consumed_rows: AtomicUsize,
+    consumed_rows: Arc<AtomicUsize>,
 }
 
 impl DataReader {
@@ -57,6 +58,15 @@ impl DataReader {
                     files.push(path);
                 }
             }
 
+            files.sort_by_key(|a| {
+                a.file_name()
+                    .map(|x| x.to_str())
+                    .flatten()
+                    .unwrap_or("")
+                    .to_owned()
+            });
+
             (Arc::new(Mutex::new(files)), total_rows)
         };
@@ -147,9 +157,9 @@ impl DataReader {
         });
 
         Ok(Self {
-            rx,
+            rx: Arc::new(Mutex::new(rx)),
             total_rows,
-            consumed_rows: AtomicUsize::new(0),
+            consumed_rows: Arc::new(AtomicUsize::new(0)),
         })
     }
 
@@ -157,7 +167,7 @@ impl DataReader {
     /// Order is arbitrary.
     /// Returns `None` when all rows have been read.
     pub fn recv(&self) -> Option<Result<String, DataReaderError>> {
-        self.rx.recv().ok()
+        self.rx.lock().recv().ok()
     }
 
     //pub fn try_recv(&self) -> Result<Result<String, DataReaderError>, TryRecvError> {
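Context for the hunks above: `DataReader` becomes `Clone` by putting the channel receiver behind `Arc<Mutex<…>>`, so every clone is a handle onto the same stream and rows are consumed exactly once across clones. The pattern in isolation (a sketch using `parking_lot::Mutex`, which the crate already depends on and which matches the unwrap-free `lock()` call; the real reader decodes parquet behind the channel):

```rust
use parking_lot::Mutex;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::{Receiver, channel};

/// Cloneable consumer handle over a single channel.
#[derive(Clone)]
struct SharedReader {
    rx: Arc<Mutex<Receiver<String>>>,
    consumed: Arc<AtomicUsize>,
}

impl SharedReader {
    fn recv(&self) -> Option<String> {
        // parking_lot's lock() cannot poison, hence no unwrap.
        let item = self.rx.lock().recv().ok();
        if item.is_some() {
            self.consumed.fetch_add(1, Ordering::Relaxed);
        }
        item
    }
}

fn main() {
    let (tx, rx) = channel();
    for i in 0..3 {
        tx.send(format!("row {i}")).unwrap();
    }
    drop(tx); // close the channel so recv() eventually returns None
    let reader = SharedReader {
        rx: Arc::new(Mutex::new(rx)),
        consumed: Arc::new(AtomicUsize::new(0)),
    };
    let clone = reader.clone(); // shares the same channel and counter
    while let Some(row) = clone.recv() {
        println!("{row}");
    }
    assert_eq!(reader.consumed.load(Ordering::Relaxed), 3);
}
```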
@@ -12,18 +12,16 @@ use tracing_subscriber::{
 // MARK: loglevel
 //
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, ValueEnum)]
-#[derive(Default)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, ValueEnum, Default)]
 pub enum LogLevel {
     Trace,
     Debug,
     #[default]
     Info,
     Warn,
     Error,
 }
 
-
 impl Display for LogLevel {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
@@ -71,7 +69,7 @@ impl From<LoggingConfig> for EnvFilter {
                 //
                 // Bins
                 //
-                format!("nanochat_rs={}", conf.nanochat),
+                format!("llmfs={}", conf.nanochat),
                 conf.other.to_string(),
             ]
             .join(","),
@@ -216,16 +214,14 @@ pub enum LoggingTarget {
 }
 
 /// How to print logs
-#[derive(Debug, Clone, Copy, Deserialize)]
-#[derive(Default)]
+#[derive(Debug, Clone, Copy, Deserialize, Default)]
 pub enum LoggingFormat {
     #[default]
     Ansi,
     AnsiNoColor,
     Json,
 }
 
-
 pub struct LoggingInitializer {
     /// Log filter for printed logs
     pub preset: LogFilterPreset,
@@ -1,3 +1,5 @@
+#![recursion_limit = "256"]
+
 use clap::Parser;
 use indicatif::MultiProgress;
 use tracing::error;