Production-Ready JWT Validation in Axum: A Real Implementation
Learn how to safely and securely extract and validate JWTs in your Axum API. Build secure apps with Rust and Axum.
JWT validation in Rust web services is one of those things that seems simple until you're debugging key rotation failures at 2 AM. Let's walk through a production-tested implementation that handles all the edge cases you'll actually encounter.
We're using Kinde for authentication here, but these patterns apply to any OIDC provider - Auth0, Okta, or your own JWT issuer. The real meat is in how we handle JWKS rotation, caching, and graceful failures.
The Architecture
Here's what we're building:
- Axum middleware that validates JWTs on protected routes
- JWKS client with automatic key rotation and caching
- Redis-backed distributed cache for multi-instance deployments
- Background refresh to avoid thundering herd problems
- Proper error handling that doesn't leak sensitive info
Core Token Validation
Let's start with the heart of the system - validating tokens against rotating public keys:
use std::sync::Arc;

use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
use tokio::sync::RwLock;
use tracing::debug;

/// Token validator for JWTs with JWKS support
#[derive(Clone)]
pub struct TokenValidator {
    /// Shared authentication configuration
    config: SharedAuthConfig,
    /// JWKS client for fetching public keys
    jwks_client: Arc<RwLock<JwksClient>>,
}

impl TokenValidator {
    pub async fn validate(&self, token: &str) -> Result<KindeClaims, AuthError> {
        // Decode header to get the key ID and algorithm
        let header = decode_header(token)?;

        // This is critical - validate the algorithm against your whitelist.
        // Never trust the alg header blindly.
        let alg_str = format!("{:?}", header.alg);
        if !self.config.allowed_algorithms.contains(&alg_str) {
            return Err(AuthError::UnsupportedAlgorithm(alg_str));
        }

        let kid = header.kid.ok_or(AuthError::InvalidToken)?;

        // Fetch the public key from JWKS (cached) via the helper below
        let decoding_key = self.get_decoding_key(&kid).await?;

        // Set up validation with proper security parameters
        let mut validation = Validation::new(header.alg);
        validation.set_issuer(&[&self.config.issuer]);
        validation.set_audience(&[&self.config.audience]);
        validation.validate_exp = true;
        validation.validate_nbf = true;
        validation.leeway = self.config.clock_skew_seconds; // Handle clock drift

        // Decode and validate in one atomic operation
        let token_data = decode::<KindeClaims>(token, &decoding_key, &validation)?;

        // Additional business logic validation
        if self.config.require_verified_email && !token_data.claims.email_verified {
            return Err(AuthError::EmailNotVerified);
        }

        Ok(token_data.claims)
    }

    /// Get the decoding key for a given key ID
    async fn get_decoding_key(&self, kid: &str) -> Result<DecodingKey, AuthError> {
        let jwks_client = self.jwks_client.read().await;
        jwks_client.get_key(kid).await.map_err(|e| {
            debug!("Failed to get key {}: {:?}", kid, e);
            AuthError::JwksKeyNotFound(kid.to_string())
        })
    }
}
Let's break down what's happening here:
Algorithm Validation: The first security check happens before any expensive cryptographic operations. We extract the algorithm from the JWT header and validate it against our whitelist (typically just RS256). This prevents algorithm confusion attacks where an attacker might try to switch from RS256 to HS256 and use the public key as a symmetric key.
Key ID (kid) Extraction: JWTs signed with asymmetric algorithms include a kid parameter in the header that tells us which public key to use for verification. This is crucial for key rotation - the identity provider can have multiple active keys.
JWKS Client Integration: The jwks_client.get_key() call is where the magic happens. This fetches the public key from our cached JWKS (JSON Web Key Set), handling cache misses, refresh, and Redis distribution transparently.
Validation Parameters: The Validation struct from the jsonwebtoken crate handles the standard JWT validation:
- set_issuer: Ensures the token came from our expected identity provider
- set_audience: Verifies the token was issued for our application
- validate_exp and validate_nbf: Check expiration and "not before" timestamps
- leeway: Critical for production - allows clock drift between servers (30 seconds in our configuration)
Business Logic Validation: After cryptographic validation passes, we check business rules. In this case, we optionally require email verification. This separation keeps security concerns distinct from business logic.
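The KindeClaims struct itself isn't shown in these excerpts. Here's a minimal sketch of what it might look like, assuming only the fields this article actually reads (your provider's claim set will differ):

use serde::Deserialize;

/// Claims deserialized from the validated token.
/// A sketch - the exact field set depends on your identity provider.
#[derive(Debug, Clone, Deserialize)]
pub struct KindeClaims {
    /// Subject: the unique user ID
    pub sub: String,
    /// Email address, if the provider includes it
    pub email: Option<String>,
    /// Whether the provider has verified the email
    #[serde(default)]
    pub email_verified: bool,
    /// Roles assigned to the user
    #[serde(default)]
    pub roles: Vec<String>,
    /// Fine-grained permissions
    #[serde(default)]
    pub permissions: Vec<String>,
}

The #[serde(default)] attributes keep deserialization from failing when a provider omits optional claims.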
JWKS Client with Smart Caching
Public keys rotate. Your service needs to handle this without downtime or performance degradation:
use std::sync::Arc;
use std::time::Duration;

use jsonwebtoken::DecodingKey;
use reqwest::Client;
use tokio::sync::RwLock;
use tokio::time::interval;
use tokio_util::sync::CancellationToken;
use tracing::{error, info};

pub struct JwksClient {
    /// Shared authentication configuration
    config: SharedAuthConfig,
    /// HTTP client for fetching JWKS
    http_client: Client,
    /// In-memory cache for JWKS
    cache: Arc<RwLock<Option<CacheEntry>>>,
    /// Redis connection for distributed caching (optional)
    redis: Option<redis::aio::ConnectionManager>,
    /// Background refresh handle
    refresh_handle: Option<tokio::task::JoinHandle<()>>,
    /// Cancellation token for background tasks
    cancel_token: CancellationToken,
}

impl JwksClient {
    /// Create a new JWKS client with configuration
    pub fn new(config: SharedAuthConfig, redis: Option<redis::aio::ConnectionManager>) -> Self {
        let http_client = Client::builder()
            .timeout(Duration::from_secs(10))
            .user_agent("Aseri-Server/1.0")
            .build()
            .expect("Failed to create HTTP client");

        Self {
            config,
            http_client,
            cache: Arc::new(RwLock::new(None)),
            redis,
            refresh_handle: None,
            cancel_token: CancellationToken::new(),
        }
    }

    /// Start background refresh task
    pub fn start_background_refresh(mut self) -> Self {
        let cache = self.cache.clone();
        let config = self.config.clone();
        let http_client = self.http_client.clone();
        let redis = self.redis.clone();
        let jwks_url = config.jwks_url.clone();
        let cancel_token = self.cancel_token.clone();

        let handle = tokio::spawn(async move {
            let mut refresh_interval = interval(config.jwks_refresh_interval());
            refresh_interval.tick().await; // Skip first immediate tick

            loop {
                tokio::select! {
                    _ = refresh_interval.tick() => {
                        info!("Background JWKS refresh triggered");
                        if let Err(e) = Self::refresh_jwks_static(
                            &http_client,
                            &jwks_url,
                            &cache,
                            redis.as_ref(),
                            config.jwks_cache_ttl(),
                        ).await {
                            error!("Background JWKS refresh failed: {:?}", e);
                        }
                    }
                    _ = cancel_token.cancelled() => {
                        info!("Background JWKS refresh task cancelled");
                        break;
                    }
                }
            }
        });

        self.refresh_handle = Some(handle);
        self
    }

    pub async fn get_key(&self, kid: &str) -> Result<DecodingKey, AuthError> {
        // Try cache first
        if let Some(key) = self.get_cached_key(kid).await {
            return self.jwk_to_decoding_key(&key);
        }

        // Cache miss - fetch fresh JWKS
        self.refresh_jwks().await?;

        // Try again after refresh; fail with the kid if it's still missing
        self.get_cached_key(kid)
            .await
            .ok_or_else(|| AuthError::JwksKeyNotFound(kid.to_string()))
            .and_then(|key| self.jwk_to_decoding_key(&key))
    }
}
This JWKS client solves several production problems:
Graceful Shutdown with Cancellation Tokens: The tokio::select! macro allows us to listen for both the refresh interval and a cancellation signal simultaneously. When your service receives a shutdown signal (SIGTERM in Kubernetes, for example), you can trigger the cancel_token to cleanly stop the background task. This prevents zombie tasks and ensures clean container restarts.
// During application shutdown. We take the client by value so we can
// move the JoinHandle out - awaiting a JoinHandle requires ownership.
async fn shutdown_handler(mut jwks_client: JwksClient) {
    jwks_client.cancel_token.cancel();
    // Wait for the background task to finish
    if let Some(handle) = jwks_client.refresh_handle.take() {
        let _ = handle.await;
    }
}
Background Refresh Strategy: The background task refreshes keys one hour before they expire (a 23-hour refresh interval for a 24-hour TTL; see the sketch after this list). This proactive approach means:
- Normal request paths never experience JWKS fetch latency
- If the identity provider is temporarily down, we have a 1-hour buffer
- No thundering herd when keys expire across multiple instances
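The code above calls config.jwks_refresh_interval() and config.jwks_cache_ttl() without showing them. A minimal sketch of that relationship, assuming an AuthConfig type behind SharedAuthConfig and a TTL longer than an hour:

use std::time::Duration;

impl AuthConfig {
    /// How long a fetched JWKS stays valid in the cache (24 hours here)
    pub fn jwks_cache_ttl(&self) -> Duration {
        Duration::from_secs(24 * 60 * 60)
    }

    /// Refresh one hour before the cached keys would expire.
    /// saturating_sub guards against a TTL shorter than the buffer.
    pub fn jwks_refresh_interval(&self) -> Duration {
        self.jwks_cache_ttl()
            .saturating_sub(Duration::from_secs(60 * 60))
    }
}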
Two-Tier Caching Architecture:
- In-memory cache (Arc<RwLock<Option<CacheEntry>>>): Sub-millisecond access for the hot path (a sketch of CacheEntry follows below)
- Redis cache: Shared across instances, survives restarts, reduces load on the identity provider

The RwLock is crucial here - we have many readers (every request) but few writers (only during refresh). This gives us excellent concurrent read performance.
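CacheEntry never appears in the excerpts either; a minimal sketch, assuming it simply pairs the parsed key set with an expiry deadline:

use std::time::Instant;
use jsonwebtoken::jwk::JwkSet;

/// One cached JWKS snapshot (a sketch - the real fields may differ)
struct CacheEntry {
    /// The parsed JSON Web Key Set
    jwks: JwkSet,
    /// When this entry stops being trusted (fetch time + TTL)
    expires_at: Instant,
}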
Graceful Degradation: Notice the error handling in the background refresh - we log but don't panic. If refresh fails, we continue using cached keys until they expire. This prevents a temporary network issue from taking down your entire auth system.
Cache Miss Handling: When we encounter a new kid (like during key rotation), we take the following steps (sketches of the helper methods follow this list):
- Check the in-memory cache
- On miss, trigger a JWKS refresh (which also checks Redis)
- Retry the lookup
- Only fail if the key still isn't found
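Here's roughly what those helpers might look like, assuming the CacheEntry sketch above. JwkSet::find and DecodingKey::from_jwk are real APIs in the jsonwebtoken crate's jwk module; the surrounding logic is our assumption:

impl JwksClient {
    /// Look up a key by kid in the in-memory cache, ignoring expired entries
    async fn get_cached_key(&self, kid: &str) -> Option<jsonwebtoken::jwk::Jwk> {
        let cache = self.cache.read().await;
        cache
            .as_ref()
            .filter(|entry| entry.expires_at > std::time::Instant::now())
            .and_then(|entry| entry.jwks.find(kid).cloned())
    }

    /// Convert a JWK into a DecodingKey usable by jsonwebtoken::decode
    fn jwk_to_decoding_key(&self, jwk: &jsonwebtoken::jwk::Jwk) -> Result<DecodingKey, AuthError> {
        DecodingKey::from_jwk(jwk).map_err(|_| AuthError::InvalidToken)
    }
}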
This means your service automatically adapts to key rotation without configuration changes or restarts.
The Middleware Magic
Here's where it all comes together in Axum:
use axum::extract::FromRequestParts;
use axum::http::header::AUTHORIZATION;
use axum::http::request::Parts;

#[derive(Debug, Clone)]
pub struct AuthenticatedUser {
    pub user_id: String,
    pub email: Option<String>,
    pub email_verified: bool,
    pub roles: Vec<String>,
    pub permissions: Vec<String>,
    pub claims: KindeClaims,
}

impl FromRequestParts<AppState> for AuthenticatedUser {
    type Rejection = AuthError;

    async fn from_request_parts(
        parts: &mut Parts,
        state: &AppState,
    ) -> Result<Self, Self::Rejection> {
        // Extract Authorization header
        let auth_header = parts
            .headers
            .get(AUTHORIZATION)
            .ok_or(AuthError::MissingToken)?;

        // Parse Bearer token
        let token = extract_bearer_token(auth_header)?;

        // Build a validator from shared state (Arc clones are cheap)
        let validator = TokenValidator::new(
            state.auth_config.clone(),
            state.jwks_client.clone(),
        );
        let claims = validator.validate(token).await?;

        Ok(AuthenticatedUser {
            user_id: claims.sub.clone(),
            email: claims.email.clone(),
            email_verified: claims.email_verified,
            roles: claims.roles.clone(),
            permissions: claims.permissions.clone(),
            claims,
        })
    }
}
This extractor pattern is what makes Axum shine for authentication. Here's what's happening:
FromRequestParts Trait: This trait allows AuthenticatedUser to be extracted from the request before it reaches your handler. If extraction fails, the request is rejected with a proper HTTP error response - your handler never runs with invalid auth.
Shared State Access: We get the AppState, which contains our shared auth_config and jwks_client. These are Arc-wrapped and initialized once at startup, so cloning is cheap (just incrementing a reference count).
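The AppState type itself isn't shown in this article; a minimal sketch, assuming only the two fields the extractor touches:

use std::sync::Arc;
use tokio::sync::RwLock;

/// Application state shared across handlers (a sketch - yours will
/// likely also carry database pools and other resources)
#[derive(Clone)]
pub struct AppState {
    pub auth_config: SharedAuthConfig,
    pub jwks_client: Arc<RwLock<JwksClient>>,
}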
Token Extraction: The extract_bearer_token helper properly parses the Authorization: Bearer <token> header, handling the following (a sketch of the helper follows this list):
- Case-insensitive "Bearer" prefix
- Extra whitespace
- Basic JWT structure validation (must have exactly 2 dots)
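The helper itself isn't shown above. Here's one way it could look - a sketch covering the three cases in the list, not the article's actual implementation:

use axum::http::HeaderValue;

/// Pull the raw JWT out of an `Authorization: Bearer <token>` header
fn extract_bearer_token(header: &HeaderValue) -> Result<&str, AuthError> {
    let value = header.to_str().map_err(|_| AuthError::InvalidToken)?;

    // Split into scheme and credentials, tolerating extra whitespace
    let mut parts = value.trim().splitn(2, char::is_whitespace);
    let (scheme, token) = match (parts.next(), parts.next()) {
        (Some(s), Some(t)) => (s, t.trim()),
        _ => return Err(AuthError::InvalidToken),
    };

    // Case-insensitive "Bearer" prefix
    if !scheme.eq_ignore_ascii_case("bearer") {
        return Err(AuthError::InvalidToken);
    }

    // Basic JWT structure check: three dot-separated segments
    if token.split('.').count() != 3 {
        return Err(AuthError::InvalidToken);
    }

    Ok(token)
}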
Claims Transformation: We extract the fields most handlers care about into a flat structure. The full claims are still available for handlers that need custom claims, but common fields are easily accessible.
Now your handlers are clean:
async fn protected_endpoint(
    auth: AuthenticatedUser,
    // ... other extractors
) -> Result<Json<Value>, AppError> {
    // auth.user_id is guaranteed to be valid
    // No manual token validation needed
    Ok(Json(json!({ "user_id": auth.user_id })))
}
The beauty of this pattern is that your handlers don't know or care about JWT validation. They just declare they need an AuthenticatedUser, and Axum handles the rest. If validation fails, the handler never runs - the error response is automatic.
Optional Authentication Pattern
Sometimes you want authentication to be optional. Here's a pattern that handles both authenticated and anonymous users elegantly:
#[derive(Debug, Clone)]
pub struct OptionalAuth {
    pub user: Option<AuthenticatedUser>,
}

impl FromRequestParts<AppState> for OptionalAuth {
    type Rejection = std::convert::Infallible; // Never fails

    async fn from_request_parts(
        parts: &mut Parts,
        state: &AppState,
    ) -> Result<Self, Self::Rejection> {
        match AuthenticatedUser::from_request_parts(parts, state).await {
            Ok(user) => Ok(OptionalAuth { user: Some(user) }),
            Err(_) => Ok(OptionalAuth { user: None }), // Silently continue
        }
    }
}
The key insight here is the Infallible rejection type - this extractor literally cannot fail. This is perfect for endpoints that behave differently for authenticated vs anonymous users:
async fn mixed_endpoint(
    auth: OptionalAuth,
) -> Result<Json<Value>, AppError> {
    match auth.user {
        Some(user) => {
            // Return personalized content
            Ok(Json(json!({
                "message": format!("Welcome back, {}", user.email.unwrap_or_default()),
                "personalized": true
            })))
        }
        None => {
            // Return public content
            Ok(Json(json!({
                "message": "Please log in for personalized content",
                "personalized": false
            })))
        }
    }
}
This pattern is useful for:
- Home pages that show different content when logged in
- API endpoints with different rate limits for authenticated users
- Public endpoints that add extra data for authenticated users
Security Considerations
Don't skip these in production:
- Algorithm Whitelisting: Only accept algorithms you explicitly support (typically just RS256)
- Clock Skew Tolerance: 30 seconds handles most real-world scenarios without being too permissive
- Automatic Retry with Exponential Backoff: JWKS endpoints can be flaky (a sketch follows the error-handling discussion below)
- Secure Token Storage: Use ZeroizeOnDrop for sensitive token data
- Proper Error Messages: Never leak why authentication failed to clients
use axum::http::{header::WWW_AUTHENTICATE, HeaderValue, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::Json;
use serde_json::json;

impl IntoResponse for AuthError {
    fn into_response(self) -> Response {
        let (status, message) = match &self {
            AuthError::TokenExpired =>
                (StatusCode::UNAUTHORIZED, "Authentication token expired"),
            AuthError::EmailNotVerified =>
                (StatusCode::FORBIDDEN, "Email verification required"),
            // Don't leak specific validation failures
            _ => (StatusCode::UNAUTHORIZED, "Authentication failed"),
        };

        // Add WWW-Authenticate header for 401s (RFC compliance)
        let mut response = (status, Json(json!({ "error": message }))).into_response();
        if status == StatusCode::UNAUTHORIZED {
            response.headers_mut().insert(
                WWW_AUTHENTICATE,
                HeaderValue::from_static("Bearer"),
            );
        }
        response
    }
}
This error handling is carefully designed:
Information Hiding: We log detailed errors for debugging (warn! and debug! throughout the code), but clients only see generic messages. This prevents attackers from learning about your system through error messages.
Status Code Semantics:
- 401 Unauthorized: Authentication failed (bad token, expired, wrong signature)
- 403 Forbidden: Authentication succeeded but authorization failed (like unverified email)
RFC Compliance: The WWW-Authenticate: Bearer header is required by RFC 6750 for 401 responses. Many clients expect this header and use it to trigger re-authentication flows.
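To make the retry bullet from the security checklist concrete, here's a sketch of a JWKS fetch with exponential backoff. The AuthError::JwksFetchFailed variant and the retry parameters are assumptions, not the article's actual code:

use std::time::Duration;
use jsonwebtoken::jwk::JwkSet;
use reqwest::Client;
use tracing::warn;

/// Fetch the JWKS, retrying with exponential backoff on transient failures
async fn fetch_jwks_with_backoff(client: &Client, url: &str) -> Result<JwkSet, AuthError> {
    let mut delay = Duration::from_millis(250);
    for attempt in 1..=4 {
        match client.get(url).send().await {
            Ok(resp) if resp.status().is_success() => {
                return resp.json::<JwkSet>().await
                    .map_err(|_| AuthError::JwksFetchFailed);
            }
            Ok(resp) => warn!("JWKS fetch attempt {} returned {}", attempt, resp.status()),
            Err(e) => warn!("JWKS fetch attempt {} failed: {}", attempt, e),
        }
        tokio::time::sleep(delay).await;
        delay *= 2; // 250ms -> 500ms -> 1s -> 2s
    }
    Err(AuthError::JwksFetchFailed)
}

Note that we log and retry rather than bubbling the first error up - consistent with the graceful-degradation stance taken in the background refresh task.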
Performance Optimizations
These optimizations actually matter at scale:
- Cache Everything: JWKS keys, validated tokens (if appropriate), user permissions
- Background Refresh: Refresh JWKS 1 hour before expiry to avoid latency spikes
- Connection Pooling: Reuse HTTP clients and Redis connections
- Fail Fast: Validate token structure before expensive cryptographic operations
The performance gains are real:
- JWKS fetch: ~100-200ms over the network
- Redis cache hit: ~1-2ms
- In-memory cache hit: ~10-100 microseconds
At 10,000 requests per second, that's the difference between 100% CPU usage and 10%.
Wrapping Up
This implementation has been battle-tested handling millions of requests. The key takeaways:
- Use extractors to make auth transparent to your handlers
- Cache aggressively but refresh proactively
- Handle edge cases (key rotation, clock skew, network failures) explicitly
- Keep security errors vague for clients, detailed for logs
The full implementation includes comprehensive error handling, tracing, and metrics - but these patterns form the foundation of any production JWT validation system.
Whether you're using Kinde, Auth0, or rolling your own OIDC provider, these patterns will save you from those 2 AM debugging sessions. Trust me, your future self will thank you for getting the caching strategy right the first time.