Chapter 5: training loop
@@ -16,10 +16,10 @@ pub enum SubCommand {
        args: train_tokenizer::TrainTokenizerArgs,
    },

    /// Sample data
    SampleData {
    /// Train model
    TrainModel {
        #[command(flatten)]
        args: sample_data::SampleDataArgs,
        args: sample_data::TrainModelArgs,
    },
}

@@ -28,7 +28,7 @@ impl SubCommand {
        match self {
            Self::Download { args } => args.run(mp),
            Self::TrainTokenizer { args } => args.run(mp),
            Self::SampleData { args } => args.run(mp),
            Self::TrainModel { args } => args.run(mp),
        }
    }
}
@@ -1,221 +1,351 @@
use ahash::AHasher;
use anyhow::{Context, Result};
use burn::{
    Tensor,
    backend::{Cuda, cuda::CudaDevice},
    module::{Module, Param, ParamId},
    backend::{Autodiff, Cuda, cuda::CudaDevice},
    config::Config,
    module::{AutodiffModule, Module, Param, ParamId},
    nn::{
        Dropout, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
        loss::CrossEntropyLossConfig,
        transformer::{PositionWiseFeedForward, PositionWiseFeedForwardConfig},
    },
    prelude::Backend,
    optim::{AdamConfig, GradientsParams, Optimizer},
    prelude::{Backend, ToElement},
    tensor::{Bool, Distribution, Int, activation::softmax},
};
use burn_train::ClassificationOutput;
use clap::Args;
use indicatif::MultiProgress;
use ndarray::Array2;
use std::{f32, fs::File, path::PathBuf};
use ndarray::{Array1, Array2};
use std::{
    collections::VecDeque,
    f32,
    fs::File,
    hash::Hasher,
    path::{Path, PathBuf},
};
use tokenizer::Tokenizer;
use tracing::{debug, info};

use crate::data_reader::DataReader;
use crate::data_reader::{DataReader, DataReaderError};

// Text generation routine

/*
{
    let init = "Initial context. This is ";
    let tokens = tokenizer.encode(&init);

    let n_tokens = tokens.len();
    let input: Array1<u32> = Array1::from_vec(tokens);
    let mut input: Tensor<Cuda, 1, Int> =
        Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
            .reshape([n_tokens]);

    for _ in 0..100 {
        let tokens: Vec<u32> = input.clone().to_data().convert::<u32>().into_vec().unwrap();
        println!("{:?}", tokens);
        println!("{}", tokenizer.decode(&tokens));

        // Crop idx to context size
        let batch = input
            .clone()
            .slice([0..config.context_size])
            .unsqueeze_dim(0);

        // shape: [tokens, vocab_size]
        let logits = model.forward(batch).squeeze_dim::<2>(0);

        // shape: [vocab_size]
        let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0);

        let probs = softmax(logits, 0); // shape: [n_tokens]
        let id_next = probs.argmax(0); // shape: [1]

        input = Tensor::cat(vec![input.slice([1..]), id_next], 0);
    }
}
*/
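
The commented-out block above is a greedy decoding loop: run the model on the current window, take the logits of the last position, softmax them, append the argmax token, and slide the window by one. A minimal sketch of the per-step choice in plain Rust (helper name is mine, not part of the commit); since softmax is monotonic, the argmax over probabilities equals the argmax over raw logits:

    /// Greedy next-token selection: index of the largest logit.
    fn argmax(logits: &[f32]) -> usize {
        logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }

    fn main() {
        // Token 1 has the highest logit, so greedy decoding picks it.
        assert_eq!(argmax(&[0.1, 2.5, -1.0, 0.3]), 1);
    }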

struct TrainTestIterator<'a, B: Backend> {
    reader: DataReader,

    ccfg: &'a ComputeConfig,
    mcfg: &'a GptModelConfig,
    tokenizer: &'a Tokenizer,
    eval: bool,
    device: &'a B::Device,

    error: bool,

    // Tokenized input/output pairs
    pairs: VecDeque<(Vec<u32>, u32)>,
}

impl<'a, B: Backend> TrainTestIterator<'a, B> {
    pub fn new(
        data_dir: impl AsRef<Path>,
        ccfg: &'a ComputeConfig,
        mcfg: &'a GptModelConfig,
        tokenizer: &'a Tokenizer,
        eval: bool,
        device: &'a B::Device,
    ) -> Result<Self, std::io::Error> {
        let reader = DataReader::new(10, data_dir)?;

        Ok(Self {
            reader,
            ccfg,
            mcfg,
            tokenizer,
            eval,
            device,

            error: false,
            pairs: VecDeque::new(),
        })
    }
}
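
One detail worth flagging before the Iterator impl below: next() assigns each (context, target) pair to the train or eval split deterministically, by hashing the pair together with eval_salt and comparing the hash against eval_frac, so the same pair lands in the same split on every run. A standalone sketch of the idea (function name is mine, not from the commit):

    use ahash::AHasher;
    use std::hash::Hasher;

    /// Deterministically decide whether a sample belongs to the eval split.
    /// Changing the salt yields a different, but still stable, split.
    fn in_eval_split(salt: &str, context: &[u32], target: u32, eval_frac: f64) -> bool {
        let mut hasher = AHasher::default();
        hasher.write(salt.as_bytes());
        for &tok in context {
            hasher.write_u32(tok);
        }
        hasher.write_u32(target);

        // Roughly `eval_frac` of all hash values fall below this threshold.
        hasher.finish() < (u64::MAX as f64 * eval_frac) as u64
    }

    fn main() {
        // Same inputs and salt give the same answer on every run.
        let (ctx, tgt) = (vec![1u32, 2, 3], 4u32);
        assert_eq!(
            in_eval_split("salt", &ctx, tgt, 0.1),
            in_eval_split("salt", &ctx, tgt, 0.1)
        );
    }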

impl<B: Backend> Iterator for TrainTestIterator<'_, B> {
    type Item = Result<TrainBatch<B>, DataReaderError>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.error {
            return None;
        }

        let mut inputs = Vec::with_capacity(self.ccfg.batch_size);
        let mut targets = Vec::with_capacity(self.ccfg.batch_size);
        let stride = self.mcfg.context_size;

        while inputs.len() < self.ccfg.batch_size {
            match self.pairs.pop_front() {
                Some((i, t)) => {
                    // train/test split
                    {
                        let mut hasher = AHasher::default();
                        hasher.write(self.ccfg.eval_salt.as_bytes());

                        // Hash the full token sequence plus the target; ahash
                        // output is unstable across versions anyway, so the
                        // exact hashing scheme does not matter.
                        for &tok in &i {
                            hasher.write_u32(tok);
                        }
                        hasher.write_u32(t);

                        // Is this pair in the test (eval) set?
                        let test =
                            hasher.finish() < (u64::MAX as f64 * self.ccfg.eval_frac).to_u64();

                        if test ^ self.eval {
                            continue;
                        }
                    }

                    inputs.push(i);
                    targets.push(t);
                }

                None => {
                    let text = match self.reader.next() {
                        None => break,
                        Some(Ok(x)) => x,
                        Some(Err(x)) => {
                            self.error = true;
                            return Some(Err(x));
                        }
                    };

                    let emb = self.tokenizer.encode(&text);

                    // Skip small texts
                    //
                    // TODO: do this better
                    // TODO: maybe using <|bos|>?
                    // TODO: non-uniform batches?
                    if emb.len() < self.mcfg.context_size {
                        continue;
                    }

                    let pairs = emb
                        .windows(self.mcfg.context_size + 1)
                        .step_by(stride)
                        .map(|x| {
                            (
                                x[..self.mcfg.context_size].to_vec(),
                                x[self.mcfg.context_size],
                            )
                        });

                    self.pairs.extend(pairs);
                }
            }
        }

        if inputs.is_empty() {
            return None;
        }

        let shape = [inputs.len(), self.mcfg.context_size];

        // Arrange data in memory
        let inputs: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| inputs[a][b]);
        let targets: Array1<u32> = Array1::from_vec(targets);

        // Create tensors on gpu
        #[expect(clippy::unwrap_used)]
        let inputs =
            Tensor::<B, 1, Int>::from_ints(inputs.as_slice().unwrap(), self.device).reshape(shape);

        #[expect(clippy::unwrap_used)]
        let targets = Tensor::<B, 1, Int>::from_ints(targets.as_slice().unwrap(), self.device);

        return Some(Ok(TrainBatch { inputs, targets }));
    }
}
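
The pair construction in next() is the standard next-token setup: every window of context_size tokens is paired with the single token that follows it, and stride == context_size makes the windows non-overlapping. A minimal sketch of just that step on a plain slice of tokens (helper name is mine):

    /// Split a token stream into (context, next-token) training pairs.
    fn make_pairs(tokens: &[u32], context_size: usize, stride: usize) -> Vec<(Vec<u32>, u32)> {
        tokens
            .windows(context_size + 1)
            .step_by(stride)
            .map(|w| (w[..context_size].to_vec(), w[context_size]))
            .collect()
    }

    fn main() {
        // With context_size = 3 and stride = 3 the windows do not overlap:
        // [1 2 3] -> 4, [4 5 6] -> 7 (the stride skips past the last window).
        let pairs = make_pairs(&[1, 2, 3, 4, 5, 6, 7, 8], 3, 3);
        assert_eq!(pairs, vec![(vec![1, 2, 3], 4), (vec![4, 5, 6], 7)]);
    }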

#[derive(Debug, Args, Clone)]
pub struct SampleDataArgs {
pub struct TrainModelArgs {
    /// Path to training data
    #[clap(long, default_value = "data")]
    data_dir: PathBuf,
    data: PathBuf,

    /// Path to tokenizer
    #[clap(long)]
    tokenizer: PathBuf,

    /// How many texts to return
    #[clap(long, short = 'n', default_value = "10")]
    n: usize,

    /// How many texts to skip
    #[clap(long, short = 's', default_value = "0")]
    skip: usize,
}

#[derive(Debug, Clone)]
pub struct Config {
    /// Number of tokens
    pub vocab_size: u32,

    /// Maximum number of input tokens with positional embeddings
    pub context_size: usize,

    /// Dimension of each token's embedding
    pub embed_dim: usize,

    /// Number of attention heads
    pub n_heads: usize,

    /// Dimension of each attn head
    pub head_dim: usize,

    /// Number of transformer blocks
    pub n_layers: usize,

    pub embed_drop_rate: f64,
    pub attention_drop_rate: f64,
    pub shortcut_drop_rate: f64,
pub struct ComputeConfig {
    pub batch_size: usize,
    pub eval_frac: f64,
    pub eval_salt: String,
}

impl SampleDataArgs {
impl TrainModelArgs {
    pub fn run(self, _mp: Option<MultiProgress>) -> Result<()> {
        let device = CudaDevice::new(0);

        let iter = DataReader::new(1, &self.data_dir).context("while initializing data reader")?;
        //let device = WgpuDevice::DiscreteGpu(0);

        let tokenizer = File::open(&self.tokenizer).context("while opening tokenizer")?;
        let tokenizer: Tokenizer =
            serde_json::from_reader(tokenizer).context("while loading tokenizer")?;

        let config = Config {
        let ccfg = ComputeConfig {
            batch_size: 10,
            eval_frac: 0.1,
            eval_salt: "salt".into(),
        };

        let mcfg = GptModelConfig {
            vocab_size: tokenizer.vocab_size(),
            context_size: 4,
            context_size: 256,
            embed_dim: 768,
            n_heads: 12,
            head_dim: 64, // = 768 / 12
            n_layers: 12,
            n_layers: 1,
            embed_drop_rate: 0.1,
            attention_drop_rate: 0.1,
            shortcut_drop_rate: 0.1,
        };

        let stride = config.context_size;

        let batch_size = 10;
        let mut input_batch = Vec::with_capacity(batch_size);
        let mut output_batch = Vec::with_capacity(batch_size);

        #[expect(clippy::unwrap_used)] // Lazy error handling
        let iter = iter.map(|x| x.unwrap()).skip(self.skip).take(self.n);

        let model = GptModel::new(&config, &device);

        // Text generation routine
        let mut model: GptModel<Autodiff<Cuda>> = mcfg.init(&device);

        /*
        {
            let init = "Initial context. This is ";
            let tokens = tokenizer.encode(&init);
        let loader_train = DataLoaderBuilder::new(batcher.clone())
            .batch_size(ccfg.batch_size)
            //.shuffle(config.seed)
            .num_workers(5)
            .build(Loader::new(&self.data_dir).context("while initializing loader")?);

            let n_tokens = tokens.len();
            let input: Array1<u32> = Array1::from_vec(tokens);
            let mut input: Tensor<Cuda, 1, Int> =
                Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
                    .reshape([n_tokens]);
        let loader_test = DataLoaderBuilder::new(batcher)
            .batch_size(ccfg.batch_size)
            //.shuffle(config.seed)
            .num_workers(5)
            .build(Loader::new(&self.data_dir).context("while initializing loader")?);

            for _ in 0..100 {
                let tokens: Vec<u32> = input.clone().to_data().convert::<u32>().into_vec().unwrap();
                println!("{:?}", tokens);
                println!("{}", tokenizer.decode(&tokens));
        let learner = LearnerBuilder::new("./tmp")
            .metric_train_numeric(AccuracyMetric::new())
            .metric_valid_numeric(AccuracyMetric::new())
            .metric_train_numeric(LossMetric::new())
            .metric_valid_numeric(LossMetric::new())
            .with_file_checkpointer(CompactRecorder::new())
            .learning_strategy(LearningStrategy::SingleDevice(device.clone()))
            .num_epochs(10)
            .summary()
            .build(model, AdamConfig::new().init(), 1e-4);

                // Crop idx to context size
                let batch = input
                    .clone()
                    .slice([0..config.context_size])
                    .unsqueeze_dim(0);

                // shape: [tokens, vocab_size]
                let logits = model.forward(batch).squeeze_dim::<2>(0);

                // shape: [vocab_size]
                let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0);

                let probs = softmax(logits, 0); // shape: [n_tokens]
                let id_next = probs.argmax(0); // shape: [1]

                input = Tensor::cat(vec![input.slice([1..]), id_next], 0);
            }
        }
        learner.fit(loader_train, loader_test);
        */

        for i in iter {
            let tokens = tokenizer.encode(&i);
        // Initialize optimizer
        let mut optim = AdamConfig::new().init();
        let learning_rate = 1e-4;

            // Skip small texts.
            // TODO: do this better
            // TODO: non-uniform batches?
            if tokens.len() < config.context_size {
                continue;
        for epoch in 0..10 {
            debug!("Running epoch {epoch}");

            // Training phase
            let mut train_loss_sum = 0.0;
            let mut train_total = 0;

            for batch in
                TrainTestIterator::new(&self.data, &ccfg, &mcfg, &tokenizer, false, &device)
                    .context("while initializing reader")?
            {
                let batch = batch.context("while reading batch")?;

                // Forward pass with gradients
                let output = model.forward_train(batch.inputs, batch.targets);

                train_total += output.targets.dims()[0] as i32;
                train_loss_sum += output.loss.clone().into_scalar().to_f32();

                let grads = output.loss.backward();
                let grads = GradientsParams::from_grads(grads, &model);

                model = optim.step(learning_rate, model, grads);
            }

            for (a, b) in tokens.windows(config.context_size).step_by(stride).zip(
                tokens[stride..]
                    .windows(config.context_size)
                    .step_by(stride),
            ) {
                input_batch.push(a.to_owned());
                output_batch.push(b.to_owned());
            let mut valid_loss_sum = 0.0;
            let mut valid_total = 0;

                /*
                let context = a;
                let desired = &b[b.len() - 1..];
                println!("{context:?} -> {desired:?}");
            let mut n_eval = 0;
            debug!("Evaluating batches");

                let input = tokenizer.decode(context);
                let target = tokenizer.decode(desired);
                println!("{input:?} -> {target:?}");
                */
            for batch in TrainTestIterator::new(&self.data, &ccfg, &mcfg, &tokenizer, true, &device)
                .context("while initializing reader")?
            {
                let batch = batch.context("while reading batch")?;
                n_eval += batch.targets.shape()[0];

                if input_batch.len() >= batch_size {
                    let shape = [input_batch.len(), config.context_size];
                // Forward pass without gradients
                let output = model.valid().forward_train(batch.inputs, batch.targets);

                    let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
                    let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);

                    #[expect(clippy::unwrap_used)]
                    let input: Tensor<Cuda, 2, Int> =
                        Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
                            .reshape(shape);

                    let output =
                        std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
                    let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);

                    #[expect(clippy::unwrap_used)]
                    let output: Tensor<Cuda, 2, Int> =
                        Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device)
                            .reshape(shape);

                    self.batch(&config, input, &model);
                }
                valid_total += output.targets.dims()[0] as i32;
                valid_loss_sum += output.loss.into_scalar().to_f32();
            }
        }

            if !input_batch.is_empty() {
                let shape = [input_batch.len(), config.context_size];
            // Compute and log epoch results
            let train_loss = if train_total > 0 {
                train_loss_sum / train_total as f32
            } else {
                0.0
            };
            let valid_loss = if valid_total > 0 {
                valid_loss_sum / valid_total as f32
            } else {
                0.0
            };

                let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
                let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);

                #[expect(clippy::unwrap_used)]
                let input: Tensor<Cuda, 2, Int> =
                    Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device).reshape(shape);

                let output = std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
                let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);

                #[expect(clippy::unwrap_used)]
                let output: Tensor<Cuda, 2, Int> =
                    Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device).reshape(shape);

                self.batch(&config, input, &model);
            info!(message = "Ran epoch", epoch, train_loss, valid_loss, n_eval);
        }

        Ok(())
    }

    fn batch(&self, _cfg: &Config, input: Tensor<Cuda, 2, Int>, model: &GptModel<Cuda>) {
        let logits = model.forward(input);
        println!("{:?}", logits.shape());
    }
}

//
// MARK: model
//

/// Multihead attention.
///
/// Equivalent to many stacked CausalAttention layers.
@@ -315,7 +445,7 @@ impl<B: Backend> MultiheadAttention<B> {
            },
            device.clone(),
            true,
            [embedding_dim, total_dim].into(),
            [total_dim, total_dim].into(),
        ),

        dropout: Dropout { prob: dropout },
@@ -389,6 +519,7 @@ impl<B: Backend> MultiheadAttention<B> {
        let mask = self
            .utri_mask
            .clone()
            .slice([0..tokens, 0..tokens])
            .unsqueeze_dim::<3>(0)
            .unsqueeze_dim::<4>(0)
            .expand(attn_scores.shape());
@@ -422,35 +553,50 @@ impl<B: Backend> MultiheadAttention<B> {
    }
}

#[derive(Module, Debug)]
pub struct GptModel<B: Backend> {
    embedder_tok: Embedding<B>,
    embedder_pos: Embedding<B>,
    embedder_drop: Dropout,
#[derive(Config, Debug)]
pub struct GptModelConfig {
    /// Number of tokens
    pub vocab_size: u32,

    trf_blocks: Vec<TransformerBlock<B>>,
    final_norm: LayerNorm<B>,
    out_head: Param<Tensor<B, 2>>,
    /// Maximum number of input tokens with positional embeddings
    pub context_size: usize,

    /// Dimension of each token's embedding
    pub embed_dim: usize,

    /// Number of attention heads
    pub n_heads: usize,

    /// Dimension of each attn head
    pub head_dim: usize,

    /// Number of transformer blocks
    pub n_layers: usize,

    pub embed_drop_rate: f64,
    pub attention_drop_rate: f64,
    pub shortcut_drop_rate: f64,
}

impl<B: Backend> GptModel<B> {
    pub fn new(cfg: &Config, device: &B::Device) -> Self {
        let out_head_shape = [cfg.embed_dim, cfg.vocab_size as usize];
impl GptModelConfig {
    pub fn init<B: Backend>(&self, device: &B::Device) -> GptModel<B> {
        let out_head_shape = [self.embed_dim, self.vocab_size as usize];

        Self {
            embedder_tok: EmbeddingConfig::new(cfg.vocab_size as usize, cfg.embed_dim).init(device),
        GptModel {
            embedder_tok: EmbeddingConfig::new(self.vocab_size as usize, self.embed_dim)
                .init(device),

            embedder_pos: EmbeddingConfig::new(cfg.context_size, cfg.embed_dim).init(device),
            embedder_pos: EmbeddingConfig::new(self.context_size, self.embed_dim).init(device),

            embedder_drop: Dropout {
                prob: cfg.embed_drop_rate,
                prob: self.embed_drop_rate,
            },

            trf_blocks: (0..cfg.n_layers)
                .map(|_| TransformerBlock::new(cfg, device))
            trf_blocks: (0..self.n_layers)
                .map(|_| TransformerBlock::new(&self, device))
                .collect(),

            final_norm: LayerNormConfig::new(cfg.embed_dim).init(device),
            final_norm: LayerNormConfig::new(self.embed_dim).init(device),

            out_head: Param::uninitialized(
                ParamId::new(),
@@ -464,13 +610,34 @@ impl<B: Backend> GptModel<B> {
            ),
        }
    }
}

#[derive(Debug, Clone)]
pub struct TrainBatch<B: Backend> {
    pub inputs: Tensor<B, 2, Int>,

    /// Correct next token for each input
    pub targets: Tensor<B, 1, Int>,
}

#[derive(Module, Debug)]
pub struct GptModel<B: Backend> {
    embedder_tok: Embedding<B>,
    embedder_pos: Embedding<B>,
    embedder_drop: Dropout,

    trf_blocks: Vec<TransformerBlock<B>>,
    final_norm: LayerNorm<B>,
    out_head: Param<Tensor<B, 2>>,
}

impl<B: Backend> GptModel<B> {
    pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
        let n_tokens = input.shape()[1];

        let embed_tok = self.embedder_tok.forward(input.clone());
        let embed_pos = self
            .embedder_tok
            .embedder_pos
            .forward(Tensor::arange(0..n_tokens as i64, &input.device()).unsqueeze_dim(0));

        let x = embed_tok + embed_pos;
@@ -481,6 +648,29 @@ impl<B: Backend> GptModel<B> {

        return logits;
    }

    pub fn forward_train(
        &self,
        inputs: Tensor<B, 2, Int>,
        targets: Tensor<B, 1, Int>,
    ) -> ClassificationOutput<B> {
        // shape: [batch, n_tokens, n_vocabulary]
        let output = self.forward(inputs);

        // Get last token
        // shape: [batch, n_vocabulary]
        let output = output.slice_dim(1, -1).squeeze_dim::<2>(1);

        let loss = CrossEntropyLossConfig::new()
            .init(&targets.device())
            .forward(output.clone(), targets.clone());

        ClassificationOutput {
            loss,
            output,
            targets,
        }
    }
}
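
forward_train scores only the final position: the [batch, n_tokens, vocab] logits are sliced down to [batch, vocab] before the cross-entropy against the next-token targets. For intuition, here is the same loss for a single position in plain f32, with no Burn involved, using natural log and the usual log-sum-exp trick:

    /// Cross-entropy of one logit vector against a target class:
    /// -ln(softmax(logits)[target]), computed via log-sum-exp for stability.
    fn cross_entropy(logits: &[f32], target: usize) -> f32 {
        let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let log_sum_exp = logits.iter().map(|&l| (l - max).exp()).sum::<f32>().ln() + max;
        log_sum_exp - logits[target]
    }

    fn main() {
        // A model that is uniformly unsure over 4 tokens pays ln(4) ≈ 1.386 nats.
        let loss = cross_entropy(&[0.0, 0.0, 0.0, 0.0], 2);
        assert!((loss - 4.0f32.ln()).abs() < 1e-6);
    }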

#[derive(Module, Debug)]
@@ -498,7 +688,7 @@ pub struct TransformerBlock<B: Backend> {
}

impl<B: Backend> TransformerBlock<B> {
    pub fn new(cfg: &Config, device: &B::Device) -> Self {
    pub fn new(cfg: &GptModelConfig, device: &B::Device) -> Self {
        Self {
            attention: MultiheadAttention::new(
                cfg.embed_dim,