keras
Framework · Free · Multi-backend Keras
Capabilities: 15 decomposed
multi-backend neural network computation with unified api
Medium confidence: Provides a single high-level API for defining models and layers that transparently dispatches numerical computation to JAX, TensorFlow, PyTorch, or OpenVINO backends, selected at import time via the KERAS_BACKEND environment variable or ~/.keras/keras.json. The framework maintains a backend-agnostic source of truth in keras/src/ with a generated public API surface in keras/api/, enabling seamless backend switching without code changes. Runtime dispatch follows two paths: symbolic execution during model construction (shape/dtype inference via compute_output_spec on KerasTensor objects) and eager execution during training/inference (forwarded to the active backend implementation).
Implements true multi-backend abstraction through keras/src/ source-of-truth architecture with auto-generated keras/api/ public surface, enabling compile-time API consistency across backends while maintaining separate backend-specific implementations in keras/src/backend/{jax,torch,tensorflow,openvino}/ directories. Uses symbolic execution path (compute_output_spec) for shape inference and eager path for actual computation, avoiding backend lock-in.
Unlike TensorFlow (TF-only) or PyTorch (PyTorch-only), Keras 3 provides true write-once-run-anywhere semantics with equal support for JAX, TensorFlow, and PyTorch through a unified API rather than framework-specific wrappers.
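The import-time backend selection described above can be sketched in plain Python. This is an illustrative toy, not Keras's actual code; the backend table and result strings are made up:

```python
import os

# Toy stand-ins for per-backend implementations (illustrative only).
_BACKENDS = {
    "jax": {"matmul": lambda a, b: "jax.numpy.matmul result"},
    "tensorflow": {"matmul": lambda a, b: "tf.linalg.matmul result"},
    "torch": {"matmul": lambda a, b: "torch.matmul result"},
}

# The backend is resolved once, at import time, from the environment --
# mirroring how Keras reads KERAS_BACKEND before loading keras/src/backend/.
_active = _BACKENDS.get(os.environ.get("KERAS_BACKEND", "tensorflow"),
                        _BACKENDS["tensorflow"])

def matmul(a, b):
    """Public op: same signature everywhere, dispatches to the active backend."""
    return _active["matmul"](a, b)
```

Because resolution happens once at module load, user code never branches on the backend; this is also why the backend cannot be switched mid-session (see Known Limitations).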
backend-agnostic layer and operation definitions
Medium confidence: Defines neural network layers (Dense, Conv2D, LSTM, etc.) and operations (numpy-compatible ops, neural network ops, core backend ops) in keras/src/ that are completely decoupled from backend implementation. Each layer inherits from a base Layer class that implements compute_output_spec() for symbolic shape/dtype inference and call() for eager execution. Backend-specific implementations are injected at runtime through the active backend module, allowing the same layer code to execute on JAX, TensorFlow, PyTorch, or OpenVINO without modification.
Implements layers as backend-agnostic Python classes with dual-path execution: symbolic path uses compute_output_spec() to infer output shapes/dtypes without computation, eager path delegates to backend-specific implementations via keras.ops.* namespace. Layer definitions in keras/src/layers/ contain zero backend-specific code; all dispatch happens through the ops module.
Compared to PyTorch (backend-specific) or TensorFlow (TF-centric), Keras layers achieve true backend independence by separating layer logic from backend implementation, allowing identical layer code to run on JAX, PyTorch, or TensorFlow without conditional logic.
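The dual-path execution described above can be sketched as a toy layer. This is an illustration of the pattern, not Keras's implementation; TensorSpec and ToyDense are hypothetical names:

```python
class TensorSpec:
    """Stand-in for KerasTensor: carries shape/dtype but no data."""
    def __init__(self, shape, dtype="float32"):
        self.shape, self.dtype = shape, dtype

class ToyDense:
    """Illustrative layer with Keras's two execution paths."""
    def __init__(self, units):
        self.units = units

    def compute_output_spec(self, spec):
        # Symbolic path: infer the output shape without computing anything.
        return TensorSpec(spec.shape[:-1] + (self.units,), spec.dtype)

    def call(self, x):
        # Eager path: real computation (a trivial stand-in here).
        return [[sum(row)] * self.units for row in x]

layer = ToyDense(4)
spec = layer.compute_output_spec(TensorSpec((32, 8)))  # symbolic: shape only
out = layer.call([[1, 2], [3, 4]])                     # eager: actual values
```

Model construction walks the symbolic path (cheap shape/dtype bookkeeping); training and inference walk the eager path on the active backend.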
callback system for training monitoring and control
Medium confidence: Provides a callback system (keras/src/callbacks/) that enables monitoring and controlling training through hooks at various training stages: on_epoch_begin, on_epoch_end, on_batch_begin, on_batch_end, on_train_begin, on_train_end. Built-in callbacks include EarlyStopping (stop training when validation metric plateaus), ModelCheckpoint (save best model), ReduceLROnPlateau (reduce learning rate), TensorBoard (visualization), and CSVLogger (log metrics). Callbacks are executed synchronously during training and have access to training state (epoch, batch, metrics, model weights).
Implements callback system in keras/src/callbacks/ with hooks at multiple training stages (epoch/batch begin/end) and built-in callbacks for common use cases (EarlyStopping, ModelCheckpoint, ReduceLROnPlateau). Callbacks are executed synchronously during training with access to training state, enabling monitoring and control without modifying training loop code.
Unlike PyTorch (no built-in callback system) or TensorFlow (callbacks are TensorFlow-specific), Keras provides a unified callback system across all backends with built-in callbacks for common use cases like early stopping and model checkpointing.
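The hook mechanism above can be sketched as a toy training loop with an early-stopping callback. This is a minimal illustration of the pattern; the real keras.callbacks.EarlyStopping has more options (monitor, min_delta, restore_best_weights):

```python
class Callback:
    """Base class: no-op hooks, like keras.callbacks.Callback."""
    def on_train_begin(self): pass
    def on_epoch_end(self, epoch, logs): pass

class ToyEarlyStopping(Callback):
    """Stop when the monitored loss fails to improve for `patience` epochs."""
    def __init__(self, patience=2):
        self.patience, self.best, self.wait = patience, float("inf"), 0
        self.stop = False
    def on_epoch_end(self, epoch, logs):
        if logs["loss"] < self.best:
            self.best, self.wait = logs["loss"], 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stop = True

def fit(losses, callbacks):
    """Toy training loop: replays a fixed loss curve through the hooks."""
    for cb in callbacks:
        cb.on_train_begin()
    ran = 0
    for epoch, loss in enumerate(losses):
        ran += 1
        for cb in callbacks:
            cb.on_epoch_end(epoch, {"loss": loss})
        if any(getattr(cb, "stop", False) for cb in callbacks):
            break
    return ran

es = ToyEarlyStopping(patience=2)
epochs_run = fit([1.0, 0.8, 0.9, 0.95, 0.7], [es])
```

Training halts after epoch 3 (two epochs without improvement on the 0.8 best), so the final 0.7 epoch never runs; the synchronous hook call is what lets callbacks mutate training state mid-run.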
metric computation and tracking during training
Medium confidence: Provides a metric system (keras/src/metrics/) for computing and tracking statistics during training and evaluation. Metrics are stateful objects that accumulate values across batches and compute aggregate statistics (accuracy, AUC, precision, recall, etc.). Metrics are compiled into models via model.compile(metrics=[...]) and automatically computed during training/evaluation. The framework provides built-in metrics for classification, regression, and ranking tasks. Metrics support both eager and graph execution modes and work identically across all backends.
Implements metrics as stateful objects in keras/src/metrics/ that accumulate values across batches and compute aggregate statistics. Metrics are compiled into models and automatically computed during training/evaluation, with support for both eager and graph execution modes across all backends.
Unlike core PyTorch (where metric computation is manual or delegated to separate libraries such as torchmetrics) or TensorFlow (whose metrics are TensorFlow-specific), Keras provides a unified metric system across all backends with built-in metrics for common use cases and automatic computation during training.
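The stateful accumulate-then-aggregate pattern above can be sketched as a toy accuracy metric. This mirrors the update_state/result/reset_state protocol of keras.metrics.Metric but is not the real implementation:

```python
class ToyAccuracy:
    """Stateful metric: accumulates across batches, then aggregates."""
    def __init__(self):
        self.reset_state()

    def reset_state(self):
        self.correct = 0
        self.total = 0

    def update_state(self, y_true, y_pred):
        # Called once per batch during fit()/evaluate().
        self.correct += sum(int(t == p) for t, p in zip(y_true, y_pred))
        self.total += len(y_true)

    def result(self):
        # Aggregate over everything seen since the last reset.
        return self.correct / self.total

acc = ToyAccuracy()
acc.update_state([1, 0, 1], [1, 1, 1])  # batch 1: 2 of 3 correct
acc.update_state([0, 0], [0, 0])        # batch 2: 2 of 2 correct
```

Accumulating counts rather than per-batch averages is what makes the aggregate correct when batch sizes differ.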
optimizer implementations with learning rate scheduling
Medium confidence: Provides optimizer implementations (keras/src/optimizers/) including SGD, Adam, RMSprop, and others that update model weights based on gradients. Optimizers are backend-agnostic and delegate gradient updates to backend-specific implementations. Learning rate scheduling is supported through LearningRateSchedule objects that adjust learning rate during training based on epoch or batch number. Optimizers support momentum, weight decay, gradient clipping, and other advanced features. All optimizers work identically across backends.
Implements optimizers as backend-agnostic objects in keras/src/optimizers/ that delegate gradient updates to backend-specific implementations. Learning rate scheduling is supported through LearningRateSchedule objects that adjust learning rate during training, with all optimizers working identically across backends.
Unlike PyTorch (whose torch.optim.lr_scheduler classes are framework-specific) or TensorFlow (whose optimizers are TensorFlow-specific), Keras provides a unified optimizer system across all backends with built-in learning rate scheduling and advanced features like gradient clipping and weight decay.
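A LearningRateSchedule is just a callable mapping step number to learning rate. The sketch below follows the formula used by exponential decay schedules (lr = initial * rate^(step/decay_steps)); it is a toy, not keras.optimizers.schedules.ExponentialDecay itself:

```python
class ToyExponentialDecay:
    """Schedule object: called with the current step, returns the lr to use."""
    def __init__(self, initial_lr, decay_steps, decay_rate):
        self.initial_lr = initial_lr
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate

    def __call__(self, step):
        return self.initial_lr * self.decay_rate ** (step / self.decay_steps)

# Halve the learning rate every 100 steps, starting from 0.1.
schedule = ToyExponentialDecay(0.1, decay_steps=100, decay_rate=0.5)
```

Passing a schedule object instead of a float is what lets the optimizer re-evaluate the learning rate at every update without user code in the loop.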
loss function computation and gradient backpropagation
Medium confidence: Provides loss functions (keras/src/losses/) for training objectives including classification losses (categorical_crossentropy, sparse_categorical_crossentropy), regression losses (mean_squared_error, mean_absolute_error), and ranking losses. Loss functions are compiled into models via model.compile(loss=...) and automatically computed during training. The framework automatically computes gradients with respect to loss using the active backend's autodiff system (JAX's jax.grad, PyTorch's autograd, TensorFlow's GradientTape). Loss computation and gradient backpropagation are handled transparently without user code.
Implements loss functions as backend-agnostic objects in keras/src/losses/ with automatic gradient computation through the active backend's autodiff system. Loss computation and backpropagation are handled transparently during training without user code, leveraging JAX's jax.grad, PyTorch's autograd, or TensorFlow's GradientTape.
Unlike PyTorch (where the training loop must call loss.backward() and optimizer.step() explicitly) or TensorFlow (whose loss functions are TensorFlow-specific), Keras provides a unified loss system across all backends with automatic gradient computation and built-in loss functions for common use cases.
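To make the loss/gradient split concrete, here is mean squared error together with its analytic gradient, i.e. the quantity the active backend's autodiff (jax.grad, torch autograd, tf.GradientTape) derives automatically so user code never writes it:

```python
def mse(y_true, y_pred):
    """Mean squared error: mean of (t - p)^2, as in keras.losses."""
    n = len(y_true)
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n

def mse_grad(y_true, y_pred):
    """Analytic gradient d(mse)/d(y_pred)_i = 2 * (p_i - t_i) / n."""
    n = len(y_true)
    return [2 * (p - t) / n for t, p in zip(y_true, y_pred)]

loss = mse([1.0, 2.0], [1.5, 2.0])
grad = mse_grad([1.0, 2.0], [1.5, 2.0])
```

In Keras, only the forward definition (the mse part) exists per loss; backpropagation of the grad part is delegated entirely to the backend.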
model introspection and weight access
Medium confidence: Provides APIs for inspecting model structure and accessing weights: model.summary() displays layer structure and parameter counts, model.get_weights() returns all weights as NumPy arrays, model.set_weights() updates weights, model.get_config() returns the model configuration as a Python dict (serializable to JSON via model.to_json()), and model.get_layer() retrieves specific layers by name. These APIs work identically across all backends and enable model analysis, weight manipulation, and configuration serialization without backend-specific code.
Implements model introspection APIs in keras/src/models/model.py that work identically across all backends, providing access to model structure, weights, and configuration without backend-specific code. Weight access converts from backend-native tensors to NumPy arrays, enabling framework-agnostic weight manipulation.
Unlike PyTorch (requires framework-specific APIs like state_dict()) or TensorFlow (requires TensorFlow-specific APIs), Keras provides unified introspection APIs across all backends with automatic conversion to NumPy for framework-agnostic weight access.
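The get_weights/set_weights round trip above can be sketched with a toy model. Plain Python lists stand in for the NumPy arrays Keras returns; ToyModel and its weight names are hypothetical:

```python
class ToyModel:
    """Minimal model exposing Keras-style weight introspection."""
    def __init__(self):
        # kernel is 1x2, bias has length 1: three parameters total.
        self._kernel = [[1.0, 2.0]]
        self._bias = [0.0]

    def get_weights(self):
        # Keras converts backend-native tensors to NumPy arrays here;
        # plain lists stand in for that framework-agnostic form.
        return [self._kernel, self._bias]

    def set_weights(self, weights):
        self._kernel, self._bias = weights

    def count_params(self):
        # What model.summary() reports per layer.
        def size(w):
            return sum(size(v) for v in w) if isinstance(w, list) else 1
        return sum(size(w) for w in self.get_weights())

source = ToyModel()
target = ToyModel()
target.set_weights(source.get_weights())  # framework-agnostic weight transfer
```

Because the exchange format is framework-neutral, the same round trip works for copying weights between models running on different backends.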
numpy-compatible operation api with backend dispatch
Medium confidence: Exposes a NumPy-compatible operation API (keras.ops.numpy.*) that mirrors NumPy's function signatures and behavior while dispatching to backend-specific implementations. Operations include array manipulation (reshape, concatenate, transpose), mathematical functions (sin, exp, matmul), and linear algebra (linalg.solve, linalg.eigh). The dispatch mechanism routes each operation call to the active backend's implementation in keras/src/backend/{backend}/numpy.py, ensuring numerical consistency across backends while leveraging backend-specific optimizations.
Implements NumPy API compatibility layer that maps NumPy function signatures to backend-specific implementations without requiring users to learn backend APIs. Each operation in keras/ops/numpy/ delegates to backend-specific versions in keras/src/backend/{jax,torch,tensorflow,openvino}/numpy.py, maintaining API consistency while preserving backend optimizations.
Unlike raw JAX/PyTorch/TensorFlow APIs (which require learning framework-specific syntax), Keras ops.numpy provides familiar NumPy semantics across all backends; unlike NumPy itself, it supports automatic differentiation and GPU acceleration through any backend.
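The signature-parity idea above can be sketched without any real backend: a public function keeps the NumPy signature while the body is looked up in a per-backend table. Everything here is an illustrative stand-in for the modules under keras/src/backend/*/numpy.py:

```python
def _py_reshape(x, shape):
    """Toy 2-D reshape of a flat list, standing in for a backend kernel."""
    rows, cols = shape
    return [x[r * cols:(r + 1) * cols] for r in range(rows)]

# One table per backend in the real thing; a single toy table here,
# resolved once at import time.
_active_backend = {"reshape": _py_reshape}

def reshape(x, shape):
    """keras.ops-style wrapper: NumPy signature, backend-dispatched body."""
    return _active_backend["reshape"](x, shape)

out = reshape([0, 1, 2, 3, 4, 5], (2, 3))
```

The user-facing signature never changes across backends; only the table entry behind it does.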
neural network operation primitives with automatic differentiation
Medium confidence: Provides a comprehensive set of neural network operations (keras.ops.nn.*) including activations (relu, sigmoid, softmax), normalization (batch_norm, layer_norm), convolution, pooling, and attention mechanisms. These operations are implemented in keras/src/ops/nn.py with backend-specific implementations in keras/src/backend/{backend}/nn.py. Each operation supports automatic differentiation through the active backend's autodiff system (JAX's jax.grad, PyTorch's autograd, TensorFlow's GradientTape), enabling gradient computation for training without explicit implementation.
Implements neural network operations as backend-agnostic functions that delegate to backend-specific implementations while preserving autodiff semantics. Each operation in keras/ops/nn/ has corresponding implementations in keras/src/backend/{jax,torch,tensorflow,openvino}/nn.py, ensuring gradients flow correctly through the active backend's autodiff system without user intervention.
Unlike framework-specific APIs (PyTorch's torch.nn.functional, TensorFlow's tf.nn), Keras nn ops provide identical semantics across backends while automatically leveraging each backend's autodiff system; unlike NumPy, these operations are differentiable by default.
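For concreteness, here are the mathematical definitions behind two of the nn ops named above, written as plain-Python toys (the real ops operate on backend tensors and are differentiable):

```python
import math

def relu(xs):
    """relu semantics: max(x, 0) elementwise."""
    return [max(x, 0.0) for x in xs]

def softmax(xs):
    """softmax semantics, with the usual max-shift for numerical stability."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([1.0, 2.0, 3.0])
```

The max-shift leaves the result unchanged mathematically but prevents exp() overflow for large logits, a standard trick the backend implementations also apply.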
model training loop with distributed training support
Medium confidence: Provides a high-level training API (model.fit(), model.train_on_batch()) that abstracts away backend-specific training mechanics. The training loop handles gradient computation, optimizer updates, metric tracking, and callback execution in a backend-agnostic manner. Distributed training is supported through backend-specific mechanisms: JAX uses jax.experimental.multihost_utils, PyTorch uses torch.distributed, TensorFlow uses tf.distribute.Strategy. The framework automatically detects available devices and distributes computation across them without requiring user code changes.
Implements a backend-agnostic training loop in keras/src/trainers/ that delegates distributed training to backend-specific mechanisms (JAX's multihost utils, PyTorch's torch.distributed, TensorFlow's tf.distribute) while maintaining identical user-facing API. Gradient computation is handled through each backend's autodiff system without explicit user code.
Unlike PyTorch (requires manual training loops) or TensorFlow (requires tf.distribute.Strategy knowledge), Keras provides a unified fit() API that automatically handles distributed training across backends with minimal configuration.
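The per-batch steps fit() runs (forward pass, loss, gradient, optimizer update) can be sketched as a toy loop fitting a one-parameter linear model y = w*x by gradient descent. This is an illustration of the loop structure, not Keras's trainer code; in Keras the gradient line is supplied by the backend's autodiff:

```python
def toy_fit(x, y, epochs=50, lr=0.1):
    """Toy fit(): forward pass, MSE loss gradient, SGD update per epoch."""
    w = 0.0
    for _ in range(epochs):
        preds = [w * xi for xi in x]                               # forward
        grad = sum(2 * (p - t) * xi                                # d(mse)/dw
                   for p, t, xi in zip(preds, y, x)) / len(x)
        w -= lr * grad                                             # SGD step
    return w

# Data lies exactly on y = 2x, so w should converge to 2.
w = toy_fit([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
```

Everything inside the loop body is what model.fit() hides behind one call, with metric updates and callback hooks interleaved at the batch and epoch boundaries.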
model serialization and export to multiple formats
Medium confidence: Provides model serialization to multiple deployment formats: SavedModel (TensorFlow format), ONNX (framework-agnostic), LiteRT (mobile), and OpenVINO (edge inference). Export is handled through keras/src/saving/ with format-specific implementations. SavedModel export preserves the full model graph and weights; ONNX export converts the model to ONNX intermediate representation for cross-framework compatibility; LiteRT export optimizes for mobile devices; OpenVINO export targets edge devices. Each format supports different deployment scenarios and optimization levels.
Implements multi-format export through keras/src/saving/ with separate export pipelines for SavedModel, ONNX, LiteRT, and OpenVINO. Each format has its own conversion logic that translates the backend-agnostic model representation to format-specific structures, enabling deployment across diverse platforms without backend-specific code.
Unlike single-format exporters (TensorFlow's SavedModel, PyTorch's ONNX export), Keras provides unified export API supporting SavedModel, ONNX, LiteRT, and OpenVINO from the same model code, enabling flexible deployment across cloud, mobile, and edge platforms.
quantization and model compression
Medium confidence: Provides quantization capabilities (keras/src/quantization/) for reducing model size and inference latency through reduced precision (int8, float16). Quantization is applied through quantization policies that specify precision for weights, activations, and computations. The framework supports post-training quantization and quantization-aware training (QAT). Quantization is implemented in a backend-agnostic manner, with backend-specific optimizations for each framework (JAX uses jax.numpy operations, PyTorch uses torch.quantization, TensorFlow uses tf.quantization).
Implements quantization as a backend-agnostic policy system in keras/src/quantization/ that applies precision reduction through DType policies. Quantization is applied uniformly across backends while leveraging backend-specific optimizations (JAX's jit compilation, PyTorch's quantization kernels, TensorFlow's quantization ops).
Unlike framework-specific quantization (PyTorch's torch.quantization, TensorFlow's tf.quantization), Keras quantization works identically across all backends through a unified policy system; unlike post-hoc quantization tools, Keras supports quantization-aware training for better accuracy.
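The arithmetic behind int8 post-training quantization is a per-tensor scale: q = round(v / scale), v ≈ q * scale. The sketch below shows the symmetric variant (zero point fixed at 0), as commonly used for weights; it is a worked toy, not Keras's quantizer:

```python
def quantize_int8(values):
    """Symmetric per-tensor quantization to the int8 range [-127, 127]."""
    scale = max(abs(v) for v in values) / 127.0
    q = [round(v / scale) for v in values]
    return q, scale

def dequantize(q, scale):
    """Recover approximate float values: v ~= q * scale."""
    return [qi * scale for qi in q]

q, scale = quantize_int8([-1.0, 0.0, 0.5, 1.27])
restored = dequantize(q, scale)
```

The model shrinks to one byte per weight plus one float scale per tensor; the rounding error bounded by scale/2 per value is what QAT trains the model to tolerate.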
dtype policies for mixed-precision training and inference
Medium confidence: Provides DType (data type) policies that specify precision for layer computations, weights, and outputs. Policies enable mixed-precision training where computations use float32 for numerical stability but weights are stored in float16 for memory efficiency. DType policies are defined in keras/src/layers/layer.py and applied through layer.dtype_policy. The framework automatically handles precision conversion during forward/backward passes, leveraging backend-specific mixed-precision support (JAX's automatic mixed precision, PyTorch's autocast, TensorFlow's mixed_float16 policy).
Implements DType policies as a layer-level configuration in keras/src/layers/layer.py that specifies computation and storage precision. Policies are applied uniformly across backends while leveraging backend-specific mixed-precision support (JAX's automatic mixed precision, PyTorch's autocast, TensorFlow's mixed_float16).
Unlike framework-specific mixed-precision APIs (PyTorch's autocast, TensorFlow's mixed_float16), Keras DType policies provide a unified interface across backends; unlike manual precision casting, policies automatically handle precision conversion during forward/backward passes.
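The storage-versus-compute split of a dtype policy can be sketched with a toy tensor that merely records its dtype (real backends actually change the bit width). The class names mirror the "mixed_float16" policy's two dtypes but are illustrative:

```python
class Tensor:
    """Toy tensor recording a dtype tag (values stay Python floats)."""
    def __init__(self, values, dtype):
        self.values, self.dtype = values, dtype
    def cast(self, dtype):
        return Tensor(self.values, dtype)

class ToyMixedFloat16Policy:
    """Mirrors "mixed_float16": float16 storage, float32 compute."""
    variable_dtype = "float16"
    compute_dtype = "float32"

def dense_forward(x, w, policy=ToyMixedFloat16Policy):
    # Upcast storage-dtype inputs/weights to the compute dtype before the
    # dot product, returning activations in the compute dtype.
    xc, wc = x.cast(policy.compute_dtype), w.cast(policy.compute_dtype)
    return Tensor([sum(a * b for a, b in zip(xc.values, wc.values))],
                  policy.compute_dtype)

w = Tensor([1.0, 1.0, 1.0, 1.0], ToyMixedFloat16Policy.variable_dtype)
y = dense_forward(Tensor([1.0, 1.0, 1.0, 1.0], "float16"), w)
```

Weights stay half-size in memory while accumulations run in float32, which is where the memory saving and the numerical stability of mixed precision both come from.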
functional and sequential model apis for rapid prototyping
Medium confidence: Provides two high-level model definition APIs: Sequential API for simple linear stacks of layers, and Functional API for complex architectures with multiple inputs/outputs and skip connections. Both APIs are defined in keras/src/models/ and compile to the same underlying Model class. Sequential API uses list-based layer stacking; Functional API uses symbolic tensor composition where layer calls return KerasTensor objects that represent computation graph structure. Both APIs support the same training, evaluation, and export capabilities.
Implements Sequential and Functional APIs as separate model definition patterns in keras/src/models/ that both compile to the same underlying Model class. Sequential uses list-based layer composition; Functional uses symbolic tensor composition with KerasTensor objects representing the computation graph structure without eager execution.
Unlike PyTorch (where nn.Sequential covers only linear stacks and anything more complex requires subclassing nn.Module), Keras offers both Sequential and Functional APIs as first-class citizens with identical training/export capabilities, enabling rapid prototyping without sacrificing model complexity.
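The structural difference between the two APIs can be sketched with plain callables standing in for layers (in real Keras the Functional API composes KerasTensor objects symbolically; ordinary function composition stands in for that here):

```python
class ToySequential:
    """Sequential style: a linear stack, each output feeds the next layer."""
    def __init__(self, layers):
        self.layers = layers
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

double = lambda x: x * 2
inc = lambda x: x + 1

seq = ToySequential([double, inc])

# Functional style composes tensors freely, so a skip connection that a
# linear stack cannot express is just one expression:
functional_skip = lambda x: inc(double(x)) + x

out_seq = seq(3)
out_skip = functional_skip(3)
```

A Sequential model is limited to the stack shape by construction; the Functional form reuses the input x in two places, which is exactly the graph structure Sequential cannot describe.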
custom layer and model subclassing for advanced architectures
Medium confidence: Provides a Subclassing API where developers inherit from keras.layers.Layer or keras.Model to implement custom layers and models with arbitrary Python logic. Subclassed layers override build() to create weights and call() to define forward computation. The framework automatically tracks weights, handles gradient computation, and manages state. Subclassing enables dynamic control flow (if/while statements), custom gradient computation (via keras.ops.custom_gradient), and arbitrary Python logic that cannot be expressed through the Sequential or Functional APIs.
Implements Subclassing API through keras.layers.Layer and keras.Model base classes that automatically track weights, handle gradient computation, and manage state through Python inheritance. Subclassed layers override build() and call() methods, enabling arbitrary Python logic while maintaining compatibility with the training loop and autodiff system.
Unlike Functional API (static computation graphs), Subclassing API enables dynamic control flow and custom gradient computation; unlike raw backend APIs (PyTorch's nn.Module, TensorFlow's tf.Module), Keras Subclassing API provides automatic weight tracking and gradient computation across all backends.
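The build()/call() contract and automatic weight tracking can be sketched with a toy base class. This illustrates the protocol only; keras.layers.Layer does far more (dtype policies, naming, serialization), and ScaleShift is a made-up example layer:

```python
class ToyLayer:
    """Toy base class: builds lazily on first call, tracks added weights."""
    def __init__(self):
        self.weights = []
        self.built = False

    def add_weight(self, value):
        # Registering through add_weight is what makes the framework aware
        # of trainable state without the user listing it anywhere.
        self.weights.append(value)
        return value

    def __call__(self, x):
        if not self.built:
            self.build(x)   # create weights on first use, knowing the input
            self.built = True
        return self.call(x)

class ScaleShift(ToyLayer):
    """Custom layer: build() creates weights, call() defines the forward pass."""
    def build(self, x):
        self.scale = self.add_weight(2.0)
        self.shift = self.add_weight(1.0)

    def call(self, x):
        return [self.scale * v + self.shift for v in x]

layer = ScaleShift()
out = layer([1.0, 2.0])
```

Deferring build() until the first call is why Keras layers usually do not need input shapes declared up front.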
Capabilities are decomposed by AI analysis. Each maps to specific user intents and improves with match feedback.
Related Artifacts (sharing capabilities)
Artifacts that share capabilities with keras, ranked by overlap. Discovered automatically through the match graph.
Keras
High-level deep learning API — multi-backend (JAX, TensorFlow, PyTorch), simple model building.
Keras 3
Multi-backend deep learning API for JAX, TF, and PyTorch.
FastAI
High-level deep learning with built-in best practices.
guidance
A guidance language for controlling large language models.
tensorflow
TensorFlow is an open source machine learning framework for everyone.
lm-evaluation-harness
EleutherAI's evaluation framework — 200+ benchmarks, powers Open LLM Leaderboard.
Best For
- ✓ research teams evaluating multiple frameworks
- ✓ production teams needing framework flexibility
- ✓ developers building framework-agnostic ML libraries
- ✓ framework developers extending Keras with custom layers
- ✓ researchers implementing novel architectures
- ✓ teams building backend-agnostic ML libraries on top of Keras
- ✓ practitioners monitoring training progress
- ✓ teams implementing early stopping and model checkpointing
Known Limitations
- ⚠ Backend must be selected at import time and cannot be changed within a single Python session
- ⚠ OpenVINO backend supports inference only, not training
- ⚠ Backend-specific optimizations and features may not be fully exposed through the unified API
- ⚠ Performance overhead from the abstraction layer adds latency compared to native framework usage
- ⚠ Custom layers must use only Keras ops or backend-agnostic operations; direct backend API calls break portability
- ⚠ Layer implementations cannot access backend-specific optimizations or features not exposed through Keras ops