Overview
This is a huge release with tons of improvements and new features.

A lot of work has gone into the autodiff system, which now supports gradient checkpointing: the forward pass of some operations can be recomputed during the backward pass instead of saving their results. Not only can this save a lot of memory during training, it also composes gracefully with kernel fusion in the backward pass.

This release also introduces the new burn-jit project, which makes it possible to create new backends that can be compiled to any GPU shader language while automatically supporting all of our optimizations. We ported the WGPU backend to this new representation, and new targets should be coming soon. Stay tuned for the next releases.

We also put a lot of care into improving the user APIs. You no longer need to implement both the init and init_with methods for parameter initialization, since parameters are now initialized lazily. In addition, it is now easier to switch between backends and precision types at runtime using the new backend bridge. These improvements are based on community feedback, and we are committed to continuously improving the APIs.
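To illustrate the lazy initialization change, here is a minimal sketch that builds a module from its config with a single init call; it assumes the Wgpu backend and the nn::LinearConfig API, and the record-loading step is only indicated in a comment.

```rust
use burn::backend::Wgpu;
use burn::nn::{Linear, LinearConfig};

type B = Wgpu;

fn main() {
    let device = Default::default();

    // Parameters are initialized lazily, so a single `init` now covers both
    // the training and the inference setup.
    let linear: Linear<B> = LinearConfig::new(128, 64).init(&device);

    // When loading saved weights, apply the record to the module afterwards
    // instead of writing a dedicated `init_with(record)` constructor
    // (record loading elided here):
    // let linear = linear.load_record(record);
    let _ = linear;
}
```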
Core User APIs
A major change in this release is that most Burn types, such as modules, optimizers, and tensors, no longer implement the Sync trait. This should not impact users who train models with the Learner struct, but it may affect those who implemented their own training loop or inference server. While modules, optimizers, and tensors can be sent to other threads, they can no longer be accessed concurrently by multiple threads. This aligns with Burn's workflow, where each tensor operation requires an owned version of the tensor. The change was made to safely reduce the number of locks needed when modifying the state of the autodiff graph, fusion state, allocation cache, and various other use cases. While not all locks have been removed, the type signature no longer poses a problem for follow-up optimizations. Note that the same tensor can still be shared across multiple threads without copying the underlying data: it simply needs to be cloned before being sent to another thread.
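Here is a minimal sketch of what this looks like in practice, assuming the Wgpu backend and standard library threads: the tensor handle is cloned before being moved into another thread, and the clone shares the underlying data instead of copying it.

```rust
use std::thread;

use burn::backend::Wgpu;
use burn::tensor::{Distribution, Tensor};

type B = Wgpu;

fn main() {
    let device = Default::default();
    let tensor = Tensor::<B, 2>::random([8, 8], Distribution::Default, &device);

    // Tensors are `Send` but no longer `Sync`: clone the handle (cheap, the
    // underlying buffer is shared) and move the clone into the worker thread.
    let for_worker = tensor.clone();
    let worker = thread::spawn(move || for_worker.sum().into_scalar());

    // The original handle remains usable on the current thread.
    let local_sum = tensor.sum().into_scalar();
    let worker_sum = worker.join().unwrap();
    println!("local: {local_sum}, worker: {worker_sum}");
}
```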
• Implementation of the above changes #1575 @nathanielsimard

Tensor
• Support signed value for Tensor::arange #1238 @Nikaidou-Shinku
• Add Tensor::unsqueeze_dims op #1236 @skewballfox
• Add support for Any, All operations to Tensor #1342 @ashdtu
• Add not_equal and not_equal_elem tensor ops #1374 @laggui
• Element wise min/max between a pair of tensors #1385 @boondocklabs
• Add is_close and all_close tensor operators #1389 @antimora
• Interpolate tensor operation (Inference Only) #1246 @Nikaidou-Shinku @antimora @ashdtu
• Autodiff/training support for Nearest Interpolation #1414 @Nikaidou-Shinku @ashdtu @antimora
• Add argwhere and nonzero boolean tensor ops #1394 @laggui
• Add bool() op for numerical tensor #1402 @antimora
• Tensor permute operator #1410 @antimora
• Add sign tensor operator #1446 @antimora
• Rename diagonal to eye tensor op and add missing entry for diagonal to Book tensor section #1449 @antimora
• Add prod and prod_dim tensor ops #1460 @antimora
• Add tril_mask, triu_mask and diag_mask ops #1479 @antimora
• Add flip tensor operator #1468 @carrotflakes
• Add tensor sorting operations #1488 #1494 @laggui
• Add topk tensor operation #1497 @laggui
• Tensor expand operator #1508 @antimora
• Provide Tensor Padding Helpers #960 #1097 @jcmullwh @antimora
• Move log_sigmoid to activation ops #1558 @laggui
• Add repeat autodiff and fusion support #1600 @louisfd

Module
• Feature Addition: PRelu Module #1328 @Arjun31415
• Implement Instance Normalization #1321 @tushushu
• Add enum module support #1337 @laggui
• Make the parameters of conv1d and conv2d public #1245 @Arjun31415
• Parameters are now lazily initialized, so you no longer need to implement both the init and init_with(record) methods for training/inference #1539 @nathanielsimard
• Support multilabel binary cross entropy #1571
• Implement Huber loss #1444 @WorldSEnder
• Feat: Add Leaky Relu Model #1467 @Arjun31415
• Feat/swiglu #1507 @ashdtu
• Add rotary positional encoding to transformer modules #1604 @ashdtu

Optimizer
• Add linear learning rate scheduler #1443 @astral4
• Exponential learning rate scheduler #1481 @rubenjr0
• Cosine Annealing learning rate scheduler with cold restarts #1481 @rubenjr0
• Add Rank0 variant to AdaptorRecordV1 and AdaptorRecordItemV1 #1442 @carrotflakes

Train
• Add multi-label classification dataset and metric #1572 @laggui
• Add learner training summary #1591 @laggui

Backend
This release also introduces the backend bridge, a new mechanism for switching between backends at runtime. While an improvement in its own right, it remains compatible with the previous ways of supporting mixed precision.
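The sketch below shows the kind of workflow the bridge enables: temporarily moving a tensor to the backend's full-precision counterpart and back. The into_full_precision and from_full_precision entry points are assumptions about the user-facing API here, so check the API reference for the exact names.

```rust
use burn::tensor::{backend::Backend, Tensor};

// Minimal sketch: run an accumulation in full precision, then come back to
// the original backend. The method names below are assumptions about the
// user-facing bridge API rather than a definitive reference.
fn sum_in_full_precision<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
    let full = x.into_full_precision(); // cross the bridge
    let acc = full.clone() + full; // the work happens in full precision
    Tensor::from_full_precision(acc) // and the result comes back
}
```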
• Implementation of the backend bridge #1529 @nathanielsimard

JIT
Significant effort has been devoted over the past few months to refactoring the previous Wgpu backend into a shader-agnostic Just-in-Time backend. All lower-level dependencies have been abstracted into the Just-in-Time Runtime trait, which requires a compiler, a compute server, and storage.
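To make the architecture concrete, here is a purely illustrative sketch of such a runtime abstraction; the trait and method names are hypothetical and do not match the actual burn-jit definitions.

```rust
// Hypothetical sketch only: the real burn-jit Runtime trait and its
// companion traits use different names and carry more detail.
trait ShaderCompiler {
    /// Intermediate representation produced by the JIT tensor operations.
    type Ir;
    /// Shader source for the target language (WGSL, SPIR-V, ...).
    type Shader;
    fn compile(&mut self, ir: Self::Ir) -> Self::Shader;
}

trait ComputeStorage {
    /// Handle to a buffer living on the device.
    type Handle;
    fn alloc(&mut self, size_bytes: usize) -> Self::Handle;
}

trait ComputeServer {
    type Storage: ComputeStorage;
    type Kernel;
    /// Launch a compiled kernel against buffers owned by the storage.
    fn execute(
        &mut self,
        kernel: Self::Kernel,
        bindings: Vec<<Self::Storage as ComputeStorage>::Handle>,
    );
}

/// A runtime ties the three pieces together; implementing one for a new
/// shader language is what lets a backend inherit the existing optimizations.
trait JitRuntime {
    type Compiler: ShaderCompiler;
    type Server: ComputeServer;
}
```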
Wgpu
• Enable burn-fusion by default #1223 @nathanielsimard
• Feature/autotune int ops #1136 @agelas
• Add runtime options in Wgpu init methods #1505 @nathanielsimard
• Decent speedup of transposed convolution @louisfd

Autodiff
Extensive work has also been undertaken on Burn's autodiff backend.
The backend now supports gradient checkpointing to reduce memory usage
and has been refactored into a client/server architecture. These
updates result in significantly less blocking when tracking gradients,
enhancing performance, particularly on smaller models. Furthermore,
various bugs have been fixed where some graph nodes weren't used,
potentially truncating the autodiff graph. Overall, these changes make
the autodiff process more reliable and efficient.
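For context, these changes sit behind the usual autodiff entry points; a minimal sketch, assuming the Wgpu inner backend, looks like this.

```rust
use burn::backend::{Autodiff, Wgpu};
use burn::tensor::{Distribution, Tensor};

// The autodiff decorator wraps an inner backend; graph tracking, gradient
// checkpointing and the client/server bookkeeping live behind this type.
type B = Autodiff<Wgpu>;

fn main() {
    let device = Default::default();
    let w = Tensor::<B, 2>::random([4, 4], Distribution::Default, &device).require_grad();
    let x = Tensor::<B, 2>::random([4, 4], Distribution::Default, &device);

    // Forward pass, then backward to populate the gradients container.
    let loss = w.clone().matmul(x).sum();
    let grads = loss.backward();

    // Gradients are fetched per tracked tensor from the container.
    let w_grad = w.grad(&grads).expect("w is tracked");
    println!("{w_grad}");
}
```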
• Improvements and refactoring of the autodiff backend #1575 #1358 @louisfd @nathanielsimard

Candle
• Upgrade to Candle 0.4.1 #1382 @laggui

Data
• Add an image folder dataset implementation #1232 #1132 @laggui
• Add burn::data::network::downloader #1283 @laggui

Import
• [PyTorchRecorder] Allow multiple pattern matches in chain #1269 @laggui
• [PyTorchRecorder] PyTorch config extraction #1323 @antimora
• [PyTorchRecorder] Pass top-level key to extract state_dict #1300 @antimora
• [PyTorchRecorder] Print debug option #1425 @antimora
• [PyTorchRecorder] Truncate debug display for NestedValue #1428 @antimora
• [PyTorchRecorder] Support for non-contiguous indexes in PyTorchFileRecorder keys #1432 @antimora
• [PyTorchRecorder] Add Enum module support #1436 @antimora
• [ONNX] Parser rewrite #1296 @skewballfox

Benchmarks
• Created the burnbench CLI #1260 @syl20bnr
• Added GitHub authentication to the burnbench CLI #1285 @syl20bnr
• Updated GitHub App ID with the official application #1397 @syl20bnr
• Implemented benchmark upload functionality to the server #1381 @syl20bnr
• Compiled benchmarks in a dedicated target directory #1435 @syl20bnr
• Enhanced benchmark result presentation with a neat table and attempted to run every benchmark #1464 @akhildevelops
• Improved access token refreshing and displayed authenticated user name #1483 @syl20bnr
• Added system information to benchmark results #1495 @syl20bnr
• Included Operating System information in benchmark results #1531 @syl20bnr
• Fixed automatic fusion activation issue with Wgpu #1542 @syl20bnr
• Tweaked and added kinds to Gelu benchmark names #1533 @syl20bnr
• Ensured backend names in JSON reports match the burnbench CLI #1375 @errordeveloper @syl20bnr
• Added 'all' choice to --benches and --backends options #1567 @syl20bnr
• Revamped burnbench output for improved readability and compactness #1568 @syl20bnr
• Added URL to browse results on the burn.dev website #1573 @syl20bnr

Bug Fix
• Fix the pow backward pass when one of the tensors wasn't tracking the gradients #1225 #1224 @nathanielsimard
• Fix batch norm on the LibTorch backend when the aggregation was on the same device #1226 @nathanielsimard
• Fix training dashboard metrics switch on macOS & Linux #1228 @nathanielsimard
• Fix a bug introduced in (#1138) where arithmetic could fail on the usize type #1287 @louisfd
• [PyTorchRecorder] Fix out of memory bug #1270 #1286 @antimora
• [PyTorchRecorder] Fix chain pattern matching when multiple patterns are provided #1273 @laggui
• Fix LogEventStore end epoch log #1314 @laggui
• Huggingface dataset importer: check that pa_type is valid before checking if is_binary #1354 @laggui
• Fix implicit casting of bool in the Wgpu backend #1391 @louisfd
• Fix switched arguments in reshape_args_usize check #1409 @jackdarlison
• Fix tch view data corruption #1434 @nathanielsimard
• Add missing Debug derive for GroupNormConfig #1482 @Arjun31415
• Numerically stable log_sigmoid #1548 @laggui
• Fix PyTorchRecorder adapt_linear when using an autodiff backend #1576 @laggui

Infrastructure
• The minimum Rust version has been updated to 1.75 #1297 @syl20bnr

Docs
• Improve the doc feature flags for docs.rs #1212 @syl20bnr
• Include the backends in the documentation #1229 @nathanielsimard
• Update TORCH_CUDA_VERSION usage #1284 @laggui
• fix(book): add missing device parameter to model.init() #1302 @apertureless
• fix(book): add missing second parameter to CrossEntropyLoss constructor #1301 @apertureless
• docs(book & examples): modify book and examples with new prelude module #1372 @bioinformatist
• Update tensor book #1401 @antimora
• Fix book MNIST reference (no more huggingface) #1471 @laggui
• Update SUPPORTED-ONNX-OPS.md #1547 @antimora
• Update book module #1557 @antimora
• Update pytorch-model.md #1570 @antimora
• Fixes to code examples in section 5.2 #1594 @hrishim

CI
• Add a semantic versioning checker #1219 @Luni-4
• Simplify CI binaries updating #1235 @Luni-4
• Trigger test suite when Cargo.lock file is updated #1326 @syl20bnr
• Fix codecov and update the .dependabot file for cargo to weekly #1320 @Luni-4
• Refactor xtask #1288 @iamricks
• Fix broken test and run-checks script #1347 @antimora
• Add stale action #1383 @Luni-4
• Update Cargo.lock workflow to trigger commit checks #1399 @syl20bnr
• Use GitHub's own action to generate GitHub App token #1437 @syl20bnr
• Add support for the new cargo metadata workspace member format #1500 @syl20bnr
• Switch codecov to informational mode #1540 @syl20bnr
• Migrate workflows to use Blaze runners #1596 @dcvz
• Add a retry when adding the kisak PPA #1599 @dcvz

Tests
• Add NaN and Inf detection in assert_approx_eq to catch potential numerical bugs #1209 @skewballfox

Misc
• Make all structs CamelCase #1316 #1311 @antimora
• Move burn crates to their own crates directory #1336 @syl20bnr
• Add sub-crates as members of the workspace #1348 @antimora
• PyTorch message updates #1344 @antimora
• Chore: update main README links to crate-specific READMEs #1415 @ekalosak
• [Wasm] Remove exit in scripts #1543 @AlexErrant
• Use num-traits for float ops #1584 @antimora