
Elegant & Performant Scientific Machine Learning in Julia
479 stars, last updated 1 month ago, started in March 2022

A Pure Julia Deep Learning Framework designed for Scientific Machine Learning

💻 Installation

import Pkg
Pkg.add("Lux")

If you are using a pre-v1 version of Lux.jl, please see the "Updating to v1" section of the documentation for instructions on how to update.
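To check whether the v1 update instructions apply to your setup, you can query the installed Lux.jl version. This is a minimal sketch that assumes Julia 1.9 or newer (where `pkgversion` is available); the warning text is only illustrative.

```julia
using Lux

# `pkgversion` (Julia >= 1.9) returns the VersionNumber of a loaded package
lux_version = pkgversion(Lux)

# Warn when the installed release predates v1
if lux_version < v"1"
    @warn "Lux.jl $lux_version is older than v1; see the Updating to v1 section of the documentation."
else
    println("Lux.jl $lux_version is installed.")
end
```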

🤸 Quickstart

using Lux, Random, Optimisers, Zygote
# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support

# Seeding
rng = Xoshiro(0)

# Construct the layer
model = Chain(BatchNorm(128), Dense(128, 256, tanh), BatchNorm(256),
              Chain(Dense(256, 1, tanh), Dense(1, 10)))

# Get the device determined by Lux
device = gpu_device()

# Parameter and State Variables
ps, st = Lux.setup(rng, model) .|> device

# Dummy Input
x = rand(rng, Float32, 128, 2) |> device

# Run the model
y, st = Lux.apply(model, x, ps, st)

# Gradients
gs = only(gradient(p -> sum(first(Lux.apply(model, x, p, st))), ps))

# Optimization
st_opt = Optimisers.setup(Optimisers.Adam(0.0001), ps)
st_opt, ps = Optimisers.update(st_opt, ps, gs)
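The quickstart performs a single optimization step. Repeating the gradient/update pair gives a basic training loop. What follows is a minimal CPU-only sketch: the toy regression target (the sum of the inputs), the mean-squared-error loss, the layer sizes, the learning rate, and the `train` helper name are illustrative assumptions, not part of Lux's API.

```julia
using Lux, Random, Optimisers, Zygote

# Hypothetical helper: fit a small MLP to y = x[1] + x[2] and
# return the loss before and after training.
function train(; epochs = 200)
    rng = Xoshiro(0)
    model = Chain(Dense(2 => 16, tanh), Dense(16 => 1))
    ps, st = Lux.setup(rng, model)

    # Toy data: 64 samples with 2 features each; the target is their sum
    x = rand(rng, Float32, 2, 64)
    y = sum(x; dims = 1)

    # Mean squared error; `st` is reused because Dense layers carry no state
    mse(p) = sum(abs2, first(Lux.apply(model, x, p, st)) .- y) / length(y)

    opt_state = Optimisers.setup(Optimisers.Adam(0.01f0), ps)
    initial_loss = mse(ps)
    for _ in 1:epochs
        gs = only(Zygote.gradient(mse, ps))              # gradient w.r.t. parameters
        opt_state, ps = Optimisers.update(opt_state, ps, gs)  # one Adam step
    end
    return initial_loss, mse(ps)
end

initial_loss, final_loss = train()
println("loss: $initial_loss -> $final_loss")
```

Wrapping the loop in a function keeps `ps` and `opt_state` local, which avoids Julia's soft-scope ambiguity for globals reassigned inside a top-level `for` loop. Models with stateful layers such as BatchNorm would also need to carry the updated `st` from one step to the next.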

📚 Examples

Look in the examples directory for self-contained usage examples. The documentation also has examples, sorted into categories.

🆘 Getting Help

For usage-related questions, please use GitHub Discussions, which allows questions and answers to be indexed. To report bugs, use GitHub Issues or, even better, send in a pull request.

๐Ÿง‘โ€๐Ÿ”ฌ Citation

If you found this library useful in academic work, please cite:

Also consider starring our GitHub repository.

๐Ÿง‘โ€๐Ÿ’ป Contributing

This section is somewhat incomplete. You can contribute by helping to finish it 😜.

🧪 Testing

The full Lux.jl test suite takes a long time to run; here's how to test only a portion of the code.

Each @testitem has corresponding tags. For example, the tests for SkipConnection are declared as:

@testitem "SkipConnection" setup=[SharedTestSetup] tags=[:core_layers] begin
    # ... test body ...
end

To run only the group that SkipConnection belongs to, select core_layers by setting the LUX_TEST_GROUP environment variable (to narrow the scope further, you can also rename the tag):

export LUX_TEST_GROUP="core_layers"

Alternatively, change the default test group directly in runtests.jl:

# const LUX_TEST_GROUP = lowercase(get(ENV, "LUX_TEST_GROUP", "all"))
const LUX_TEST_GROUP = lowercase(get(ENV, "LUX_TEST_GROUP", "core_layers"))

Be sure to restore the default value "all" before submitting your changes.
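The group mechanism can be pictured with a small, stdlib-only sketch of how an environment variable might be turned into a tag filter. The helpers `selected_tags` and `should_run` are hypothetical and only illustrate the idea; Lux's actual runtests.jl may implement the selection differently.

```julia
# Hypothetical helper: map the LUX_TEST_GROUP value to a tag filter.
# Returns `nothing` for "all" (no filtering), otherwise a one-element tag vector.
function selected_tags(env = ENV)
    group = lowercase(get(env, "LUX_TEST_GROUP", "all"))
    return group == "all" ? nothing : [Symbol(group)]
end

# Hypothetical predicate: run a test item when no filter is set,
# or when any tag in the filter is among the item's tags.
should_run(item_tags, tag_filter) =
    tag_filter === nothing || any(in(item_tags), tag_filter)

# Example: the SkipConnection item is tagged :core_layers
tag_filter = selected_tags(Dict("LUX_TEST_GROUP" => "core_layers"))
println(should_run([:core_layers], tag_filter))  # true
```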

Furthermore, to run a specific test item by name, you can use TestEnv.jl. Start by activating the Lux environment, then run the following:

using TestEnv; TestEnv.activate(); using ReTestItems

# Assuming you are in the root directory of the Lux repository
ReTestItems.runtests("test/"; name = "NAME OF THE TEST")

For the SkipConnection tests, that would be:

ReTestItems.runtests("test/"; name = "SkipConnection")

