.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/batching.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_batching.py:


Batching
=============================

In other words: how can biosym be used in deep learning applications?

.. GENERATED FROM PYTHON SOURCE LINES 7-17

.. code-block:: Python

    import os
    import time

    import jax
    import matplotlib.pyplot as plt
    import numpy as np

    from biosym.model.model import load_model
    from biosym.utils import states

.. GENERATED FROM PYTHON SOURCE LINES 56-62

Load 2D Gait Model
-------------------------

Load a 2D gait model that includes ground contact forces and torque actuators.
With 45 states and 94 constants, it is a realistic, non-trivial biomechanical
system for the batching examples that follow.

.. GENERATED FROM PYTHON SOURCE LINES 62-73

.. code-block:: Python

    # `current_dir` is assumed to be the biosym repository root; it is presumably
    # defined in source lines 18-55, which are not rendered in this example.
    model_file = os.path.join(current_dir, "tests", "models", "gait2d_torque", "gait2d_torque.yaml")
    print("Loading 2D gait model with torque actuators...")
    start_time = time.time()
    model = load_model(model_file, force_rebuild=True)
    load_time = time.time() - start_time
    print(f"Model loaded in {load_time:.3f} seconds")
    print(f"Model has {model.n_states} states and {model.n_constants} constants")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Loading 2D gait model with torque actuators...
    Replacing dynamic symbols in the EOM with the v_ states, this might take a while...
    Lambdifying the EOM took 0.8248860836029053 seconds
    Precompiling the Jacobian took 1.0173466205596924 seconds
    Precompiling the confun took 1.4790878295898438 seconds
    Precompiling the mass matrix took 0.8816077709197998 seconds
    Precompiling the forcing took 0.6875641345977783 seconds
    Model loaded in 9.181 seconds
    Model has 45 states and 94 constants




.. GENERATED FROM PYTHON SOURCE LINES 74-75

Create a batch of movement data by stacking state containers along a new
leading (batch) axis.

.. GENERATED FROM PYTHON SOURCE LINES 75-93

.. code-block:: Python

    # Default inputs of the model: a StatesDict holding the state vectors
    # (positions, velocities, accelerations, forces, etc.) and the constants
    states_dict_0 = model.default_inputs
    print(states_dict_0)

    # Create a batch of 1000 identical inputs. The states gain a leading batch
    # axis, while the constants stay unbatched (see the printed shapes below)
    batch_size = 1000
    states_ = states.stack_dataclasses([states_dict_0] * batch_size)
    print(states_)

    # Any model function can now be evaluated on the batched states with jax.vmap.
    # Here we evaluate the dynamics (constraint) function. in_axes=(0, None) maps
    # over the first axis of the first argument (states) and passes the second
    # argument (constants) unchanged to every batch element.
    dynamics_fn = jax.vmap(model.run["confun"], in_axes=(0, None))
    dynamics_output = dynamics_fn(states_.states, states_.constants)
    print("Dynamics output shape with batching:", dynamics_output.shape)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    StatesDict:
      States:
        model: (45,)
        gc_model: (0,)
        actuator_model: (6,)
        h: None
      Constants:
        model: (94,)
        gc_model: (0,)
        actuator_model: (0,)
    StatesDict:
      States:
        model: (1000, 45)
        gc_model: (1000, 0)
        actuator_model: (1000, 6)
        h: None
      Constants:
        model: (94,)
        gc_model: (0,)
        actuator_model: (0,)
    Dynamics output shape with batching: (1000, 9, 1)




.. GENERATED FROM PYTHON SOURCE LINES 94-97

Performance of batching (optional)
----------------------------------

Check which devices JAX has found (for example, whether a GPU is available). Then
compare the batched call with a Python loop that evaluates the same 1000 states
one at a time.

.. GENERATED FROM PYTHON SOURCE LINES 97-109

.. code-block:: Python

    print("Available devices:", jax.devices())

    # dynamics_fn was already called once above, so any tracing/compilation cost
    # is not included here. JAX dispatches work asynchronously, so wait for the
    # result with jax.block_until_ready before stopping the timer.
    start_time = time.time()
    dynamics_output = jax.block_until_ready(dynamics_fn(states_.states, states_.constants))
    end_time = time.time()
    print(f"Computed dynamics for batch of size {batch_size} in {end_time - start_time:.4f} seconds")

    # Same computation, one state at a time in a Python loop
    start_time = time.time()
    for i in range(batch_size):
        dynamics_output_single = model.run["confun"](states_.states[i], states_.constants)
    jax.block_until_ready(dynamics_output_single)
    end_time = time.time()
    print(f"Computed dynamics for batch of size {batch_size} without batching in {end_time - start_time:.4f} seconds")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Available devices: [CpuDevice(id=0)]
    Computed dynamics for batch of size 1000 in 0.0007 seconds
    Computed dynamics for batch of size 1000 without batching in 0.8610 seconds




.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 11.955 seconds)
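
The same batching pattern can be reproduced with plain JAX: stack identical state containers along a new leading axis, then map a per-sample function with ``jax.vmap(..., in_axes=(0, None))``. The sketch below is not biosym code. ``ToyStates``, ``stack_pytrees`` and ``toy_confun`` are hypothetical stand-ins for biosym's ``StatesDict``, ``states.stack_dataclasses`` and ``model.run["confun"]``, and the sketch assumes those biosym objects behave like ordinary JAX pytrees.

.. code-block:: Python

    from typing import NamedTuple

    import jax
    import jax.numpy as jnp


    class ToyStates(NamedTuple):
        """Hypothetical stand-in for a biosym state container (a JAX pytree)."""
        model: jax.Array           # e.g. generalized coordinates and speeds
        actuator_model: jax.Array  # e.g. actuator states


    def stack_pytrees(items):
        """Stack identically shaped pytrees along a new leading (batch) axis.

        Hypothetical helper; biosym's states.stack_dataclasses may work differently.
        """
        return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *items)


    def toy_confun(states, constants):
        """Hypothetical per-sample residual standing in for model.run["confun"]."""
        residual = states.model[:3] * constants[:3] + jnp.sum(states.actuator_model)
        return residual[:, None]  # shape (3, 1) for a single sample


    single = ToyStates(model=jnp.ones(45), actuator_model=jnp.zeros(6))
    constants = jnp.arange(94, dtype=jnp.float32)

    batch_size = 1000
    batched = stack_pytrees([single] * batch_size)

    # in_axes=(0, None): map over axis 0 of every array leaf in `states`,
    # and pass the same `constants` to every batch element.
    batched_confun = jax.jit(jax.vmap(toy_confun, in_axes=(0, None)))
    out = batched_confun(batched, constants)
    print(batched.model.shape, batched.actuator_model.shape, out.shape)
    # prints: (1000, 45) (1000, 6) (1000, 3, 1)

Because ``in_axes`` is a prefix of the argument structure, the single ``0`` applies to every leaf of the states pytree. Wrapping the mapped function in ``jax.jit`` compiles the whole batch into one call.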
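
Two JAX details affect timings like the ones above. Calls return before the computation has finished (asynchronous dispatch), and the first call of a jitted function includes compilation. The minimal, self-contained benchmark below accounts for both. It is a sketch: ``toy_confun`` is a hypothetical placeholder for ``model.run["confun"]``, and the array sizes (45 states, 94 constants, 9 outputs) only mirror the gait model above.

.. code-block:: Python

    import time

    import jax
    import jax.numpy as jnp


    def toy_confun(state, constants):
        """Hypothetical per-sample residual standing in for model.run["confun"]."""
        return jnp.tanh(state[:9] * constants[:9])[:, None]  # shape (9, 1)


    batch_size = 1000
    states = jnp.ones((batch_size, 45))
    constants = jnp.linspace(0.0, 1.0, 94)

    batched = jax.jit(jax.vmap(toy_confun, in_axes=(0, None)))
    single = jax.jit(toy_confun)

    # Warm-up calls trigger compilation, so it is excluded from the timings below
    batched(states, constants).block_until_ready()
    single(states[0], constants).block_until_ready()

    # block_until_ready() waits for the result; without it the timer would only
    # measure how long it took to enqueue the work
    t0 = time.perf_counter()
    out = batched(states, constants).block_until_ready()
    t_batched = time.perf_counter() - t0

    t0 = time.perf_counter()
    for i in range(batch_size):
        out_i = single(states[i], constants)
    out_i.block_until_ready()
    t_loop = time.perf_counter() - t0

    print(f"vmap: {t_batched:.4f} s, Python loop: {t_loop:.4f} s")

``time.perf_counter`` is used instead of ``time.time`` because it is a monotonic, high-resolution clock intended for measuring short intervals.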