Batching

In other words: how can I use biosym for deep-learning applications, which need model functions evaluated over whole batches of data?

import os
import time

import jax

from biosym.model.model import load_model
from biosym.utils import states

Load 2D Gait Model

We load a 2D gait model that includes ground contact forces and torque actuators. With 45 states and 94 constants, it is a realistic test case for batching.

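# current_dir must point to the biosym repository root; os.getcwd() is a placeholder assumption
current_dir = os.getcwd()
model_file = os.path.join(current_dir, "tests", "models", "gait2d_torque", "gait2d_torque.yaml")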
print("Loading 2D gait model with torque actuators...")
start_time = time.time()
model = load_model(model_file, force_rebuild=True)
load_time = time.time() - start_time

print(f"Model loaded in {load_time:.3f} seconds")
print(f"Model has {model.n_states} states and {model.n_constants} constants")
Loading 2D gait model with torque actuators...
Replacing dynamic symbols in the EOM with the v_ states, this might take a while...
Lambdifying the EOM took 0.8248860836029053 seconds
Precompiling the Jacobian took 1.0173466205596924 seconds
Precompiling the confun took 1.4790878295898438 seconds
Precompiling the mass matrix took 0.8816077709197998 seconds
Precompiling the forcing took 0.6875641345977783 seconds
Model loaded in 9.181 seconds
Model has 45 states and 94 constants

Create batches of movement data

# Start from the model's default inputs: a StatesDict holding the states
# (positions, velocities, accelerations, forces, etc.) and the constants of each sub-model
states_dict_0 = model.default_inputs
print(states_dict_0)

# Stack 1000 copies of the default inputs. As the printout below shows, only the
# states gain a leading batch axis; the constants stay unbatched
batch_size = 1000
states_ = states.stack_dataclasses([states_dict_0] * batch_size)
print(states_)

# Any model function can now be evaluated on the whole batch with jax.vmap.
# Here we evaluate the dynamics (constraint) function.
# in_axes=(0, None) maps over axis 0 of the first argument (states)
# and passes the second argument (constants) unchanged to every sample
dynamics_fn = jax.vmap(model.run["confun"], in_axes=(0, None))
dynamics_output = dynamics_fn(states_.states, states_.constants)
print("Dynamics output shape with batching:", dynamics_output.shape)
StatesDict:
        States:
                model: (45,)
                gc_model: (0,)
                actuator_model: (6,)
                h: None
        Constants:
                model: (94,)
                gc_model: (0,)
                actuator_model: (0,)

StatesDict:
        States:
                model: (1000, 45)
                gc_model: (1000, 0)
                actuator_model: (1000, 6)
                h: None
        Constants:
                model: (94,)
                gc_model: (0,)
                actuator_model: (0,)

Dynamics output shape with batching: (1000, 9, 1)
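The stacking and in_axes=(0, None) pattern above can be illustrated without biosym. The sketch below is a minimal, hypothetical example: ToyStates, stack_pytrees and toy_residual are stand-ins invented for illustration, not biosym APIs. It assumes that biosym's state containers behave like JAX pytrees, which they presumably must for jax.vmap to map over them. It also shows the deep-learning connection: a batch-mean loss built from the vmapped function can be differentiated with jax.grad.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyStates(NamedTuple):
    # Hypothetical stand-in for a biosym States container; NamedTuples are JAX pytrees
    q: jnp.ndarray   # positions, shape (2,)
    qd: jnp.ndarray  # velocities, shape (2,)


def stack_pytrees(trees):
    """Stack identically structured pytrees along a new leading axis.

    Hypothetical helper playing the role of states.stack_dataclasses.
    """
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


def toy_residual(s: ToyStates, constants: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical constraint function returning a (2, 1) column per sample
    return (s.qd - constants * s.q)[:, None]


batch_size = 1000
single = ToyStates(q=jnp.array([1.0, 2.0]), qd=jnp.array([0.5, -0.5]))
batched = stack_pytrees([single] * batch_size)  # every leaf gains a leading (1000,) axis
constants = jnp.array([0.3, 0.7])               # shared by all samples, left unbatched

# in_axes=(0, None): map over axis 0 of every leaf of the first argument,
# pass the second argument unchanged to every sample
batched_residual = jax.vmap(toy_residual, in_axes=(0, None))
out = batched_residual(batched, constants)
print(out.shape)  # (1000, 2, 1)


def loss(c):
    # Scalar, training-style loss over the whole batch
    return jnp.mean(batched_residual(batched, c) ** 2)


grad_c = jax.grad(loss)(constants)
print(grad_c.shape)  # (2,)
```

Because the constants are passed with in_axes=None, jax.grad returns a single gradient of shape (2,) accumulated over all 1000 samples rather than one gradient per sample.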

Performance of batching (optional)

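Check which devices JAX can use (for example, whether it finds a GPU), then time the batched call against a plain Python loop over the same 1000 samples.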

print("Available devices:", jax.devices())
start_time = time.perf_counter()
dynamics_output = dynamics_fn(states_.states, states_.constants)
# JAX dispatches work asynchronously: wait for the result before stopping the timer
jax.block_until_ready(dynamics_output)
end_time = time.perf_counter()
print(f"Computed dynamics for batch of size {batch_size} in {end_time - start_time:.4f} seconds")

start_time = time.perf_counter()
for i in range(batch_size):
    dynamics_output_single = model.run["confun"](states_.states[i], states_.constants)
jax.block_until_ready(dynamics_output_single)
end_time = time.perf_counter()
print(f"Computed dynamics for batch of size {batch_size} without batching in {end_time - start_time:.4f} seconds")
Available devices: [CpuDevice(id=0)]
Computed dynamics for batch of size 1000 in 0.0007 seconds
Computed dynamics for batch of size 1000 without batching in 0.8610 seconds
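Timings in JAX are easy to get wrong: work is dispatched asynchronously, so a timer can stop before the computation has finished, and the first call of a jitted function includes compilation time. The sketch below guards against both. It uses a hypothetical toy_confun in place of a biosym model function, and the absolute numbers will depend on your hardware and backend.

```python
import time

import jax
import jax.numpy as jnp


def toy_confun(q, c):
    # Hypothetical stand-in for a model function: a few elementwise ops per sample
    return jnp.sin(q) * c + jnp.cos(q) ** 2


batch_size = 1000
qs = jnp.linspace(0.0, 1.0, batch_size * 9).reshape(batch_size, 9)
c = jnp.float32(0.5)

single_fn = jax.jit(toy_confun)
batched_fn = jax.jit(jax.vmap(toy_confun, in_axes=(0, None)))

# Warm up both functions so compilation is not counted in the measurements
single_fn(qs[0], c).block_until_ready()
batched_fn(qs, c).block_until_ready()

print("Default backend:", jax.default_backend())  # e.g. "cpu" or "gpu"

t0 = time.perf_counter()
batched_out = batched_fn(qs, c).block_until_ready()  # wait for the result, not just dispatch
t_batched = time.perf_counter() - t0

t0 = time.perf_counter()
# The loop timing also includes stacking the per-sample results
loop_out = jnp.stack([single_fn(qs[i], c) for i in range(batch_size)]).block_until_ready()
t_loop = time.perf_counter() - t0

print(f"vmap: {t_batched:.4f} s, Python loop: {t_loop:.4f} s")
```

Wrapping the vmapped function in jax.jit lets XLA compile one kernel for the whole batch, whereas the loop pays Python and dispatch overhead once per sample.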

Total running time of the script: (0 minutes 11.955 seconds)
