基于NumPy与Redis分布式锁构建高性能的分布式滑动窗口计数器


单体应用里实现一个滑动窗口计数器相当直接,一个本地的 ConcurrentLinkedDeque 或类似的数据结构,配上一个定时清理任务就能解决问题。但在分布式环境下,这个模型瞬间崩塌。所有服务实例都需要一个统一的、精确的视图来判断某个行为在过去 N 秒内是否超出了阈值。

最初的方案自然而然地落在了 Redis 上。

V1: 基于 Sorted Set 的初步尝试

最符合直觉的方案是使用 Redis 的 Sorted Set (ZSET)。我们将每个请求事件的时间戳作为 score,一个唯一的请求ID(或简单的UUID)作为 member。

一个典型的流程如下:

  1. 记录事件: ZADD requests:{key} <timestamp> <request_id>
  2. 清理过期事件: ZREMRANGEBYSCORE requests:{key} -inf <current_timestamp - window_size>
  3. 获取窗口内计数: ZCARD requests:{key}

这个方案逻辑清晰,但并发问题很快就暴露出来。步骤2和3是一个典型的“读-改-写”序列,它不是原子的。在高并发下,两个线程可能同时执行:

  • 线程A:清理过期事件。
  • 线程B:清理过期事件(可能重复清理或基于A清理后的状态清理)。
  • 线程A:获取计数。
  • 线程B:在A之后添加新事件。
  • 线程B:获取计数。

这里的计数值会变得不可靠。

V2: 引入分布式锁保证原子性

为了解决原子性问题,引入分布式锁是标准操作。在 Spring Boot 项目中,集成 Redisson 是一个稳妥的选择。它提供了可重入锁、公平锁等多种实现。

我们的计数服务代码演变为这样:

// src/main/java/com/example/counter/SlidingWindowCounterV2.java
package com.example.counter;

import org.redisson.api.RLock;
import org.redisson.api.RScoredSortedSet;
import org.redisson.api.RedissonClient;
import org.springframework.stereotype.Service;

import java.util.UUID;
import java.util.concurrent.TimeUnit;

@Service
public class SlidingWindowCounterV2 {

    private final RedissonClient redissonClient;
    private static final String KEY_PREFIX = "sw_counter_zset:";

    public SlidingWindowCounterV2(RedissonClient redissonClient) {
        this.redissonClient = redissonClient;
    }

    /**
     * 记录一次事件并返回当前窗口内的总数
     * @param key 计数器的唯一标识
     * @param windowSizeInSeconds 窗口大小(秒)
     * @return 当前窗口内的事件数量
     */
    public long incrementAndGet(String key, int windowSizeInSeconds) {
        String redisKey = KEY_PREFIX + key;
        RLock lock = redissonClient.getLock("lock:" + redisKey);
        
        try {
            // 尝试在10秒内获取锁,锁的持有时间为60秒
            if (lock.tryLock(10, 60, TimeUnit.SECONDS)) {
                try {
                    long currentTimeMillis = System.currentTimeMillis();
                    long windowStartMillis = currentTimeMillis - (long) windowSizeInSeconds * 1000;
                    
                    RScoredSortedSet<String> sortedSet = redissonClient.getScoredSortedSet(redisKey);
                    
                    // 1. 清理过期事件
                    sortedSet.removeRangeByScoreAsync(-1, true, windowStartMillis, false);

                    // 2. 添加当前事件
                    // 使用UUID确保member唯一性
                    sortedSet.add(currentTimeMillis, UUID.randomUUID().toString());
                    
                    // 3. 设置一个合理的过期时间,防止冷数据永久占用内存
                    sortedSet.expireAsync(windowSizeInSeconds * 2, TimeUnit.SECONDS);

                    // 4. 获取当前计数值
                    return sortedSet.size();
                } finally {
                    if (lock.isHeldByCurrentThread()) {
                        lock.unlock();
                    }
                }
            } else {
                // 在真实项目中,这里应该有更完善的异常处理或重试逻辑
                throw new RuntimeException("Could not acquire lock for key: " + key);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Thread interrupted while acquiring lock", e);
        }
    }
}

配套的 Redisson 配置:

# src/main/resources/application.yml
spring:
  redis:
    host: localhost
    port: 6379

# Redisson configuration
redisson:
  singleServerConfig:
    address: "redis://127.0.0.1:6379"
    # 根据需要配置其他参数,如密码、数据库等
    # password: 
    # database: 0
// src/main/java/com/example/config/RedissonConfig.java
package com.example.config;

import org.redisson.Redisson;
import org.redisson.api.RedissonClient;
import org.redisson.config.Config;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class RedissonConfig {

    @Value("${redisson.singleServerConfig.address}")
    private String address;

    @Bean(destroyMethod = "shutdown")
    public RedissonClient redissonClient() {
        Config config = new Config();
        config.useSingleServer().setAddress(address);
        return Redisson.create(config);
    }
}

这个版本解决了原子性问题,但在性能压测下暴露了新的瓶颈。当窗口非常大(例如,统计过去一小时的请求)且请求频率极高时,ZSET 会变得非常庞大。每次操作都需要对一个巨大的 ZSET 进行清理和计数,即使在锁的保护下,这个操作本身也相当耗时。锁的持有时间变长,导致外部的线程大量阻塞,系统吞吐量急剧下降。

更重要的是,这个模型只能“计数”。如果业务需求变为“计算窗口内所有请求的平均延迟”或“窗口内请求大小的P99分位值”呢?ZSET 模型就无能为力了。我们不能把请求的元数据塞进 member 里,那样解析成本太高。

V3: 拥抱二进制与NumPy的高性能计算

问题的核心在于,我们在应用层和 Redis 之间来回传递了太多结构化的、对计算不友好的数据。Redis 擅长存储和检索,但不是一个计算引擎。而 JVM 做大规模数值计算,特别是统计计算,远不如 Python 的 NumPy 生态来得高效。

一个激进但高效的想法诞生了:将计算任务外包给一个专门的 Python 进程,并在 Redis 中使用对 NumPy 最友好的格式——原始字节数组——来存储数据。

这个架构的流程变为:

  1. Spring Boot 应用接收到请求。
  2. 获取分布式锁。
  3. 从 Redis 中获取两个 byte[]:一个存时间戳,一个存相关值(例如,请求延迟)。
  4. 将这两个 byte[] 传递给一个本地的 Python 脚本。
  5. Python 脚本使用 numpy.frombuffer 进行零拷贝加载,瞬间将字节数组转换为 NumPy array。
  6. 利用 NumPy 的向量化操作,极速完成数据清理和统计计算。
  7. 脚本将计算结果和修剪后的新 byte[] 返回。
  8. Spring Boot 应用将新的 byte[] 写回 Redis。
  9. 释放锁。
sequenceDiagram
    participant Client
    participant SpringBootApp as Spring Boot App
    participant Redisson
    participant PythonScript as Python (NumPy) Script
    participant Redis

    Client->>SpringBootApp: 发起请求 (e.g., record event)
    SpringBootApp->>Redisson: acquireLock("key")
    Redisson-->>SpringBootApp: Lock Acquired
    
    SpringBootApp->>Redis: HGETALL "sw_counter_bytes:key"
    Redis-->>SpringBootApp: 返回 timestamps_bytes 和 values_bytes
    
    SpringBootApp->>PythonScript: exec("counter.py", key, ts_bytes, val_bytes)
    
    PythonScript->>PythonScript: np.frombuffer(ts_bytes)
    PythonScript->>PythonScript: np.frombuffer(val_bytes)
    PythonScript->>PythonScript: mask = timestamps > window_start
    PythonScript->>PythonScript: new_timestamps = timestamps[mask]
    PythonScript->>PythonScript: new_values = values[mask]
    PythonScript->>PythonScript: result = np.sum(new_values)
    PythonScript->>PythonScript: new_ts_bytes = new_timestamps.tobytes()
    PythonScript->>PythonScript: new_val_bytes = new_values.tobytes()
    
    PythonScript-->>SpringBootApp: 返回 JSON {result, new_ts_bytes, new_val_bytes}
    
    SpringBootApp->>Redis: HSET "sw_counter_bytes:key" timestamps  values 
    Redis-->>SpringBootApp: OK
    
    SpringBootApp->>Redisson: releaseLock("key")
    Redisson-->>SpringBootApp: Lock Released
    
    SpringBootApp-->>Client: 返回计算结果

Python 计算脚本

这个脚本是整个方案的核心。它必须是无状态的、高效的,且只通过标准输入输出与外界通信。

# /scripts/counter_processor.py
import sys
import numpy as np
import time
import json
import base64

def process_window(timestamps_b64, values_b64, new_timestamp, new_value, window_size_seconds):
    """
    使用NumPy处理滑动窗口数据
    
    Args:
        timestamps_b64 (str): Base64编码的时间戳字节数组
        values_b64 (str): Base64编码的值字节数组
        new_timestamp (int): 新事件的时间戳 (毫秒)
        new_value (float): 新事件的值
        window_size_seconds (int): 窗口大小 (秒)
        
    Returns:
        str: JSON字符串,包含统计结果和新的字节数据
    """
    # 从Base64解码
    # 如果是空字符串,表示是第一次创建
    if timestamps_b64:
        ts_bytes = base64.b64decode(timestamps_b64)
        timestamps = np.frombuffer(ts_bytes, dtype=np.int64)
    else:
        timestamps = np.array([], dtype=np.int64)

    if values_b64:
        val_bytes = base64.b64decode(values_b64)
        values = np.frombuffer(val_bytes, dtype=np.float32)
    else:
        values = np.array([], dtype=np.float32)

    # 追加新数据
    timestamps = np.append(timestamps, new_timestamp)
    values = np.append(values, np.float32(new_value))

    # 计算窗口起始时间
    current_time_millis = int(time.time() * 1000)
    window_start_millis = current_time_millis - window_size_seconds * 1000

    # 核心:使用NumPy的向量化操作进行筛选,这比任何循环都快得多
    mask = timestamps >= window_start_millis
    
    # 应用掩码获取窗口内有效数据
    valid_timestamps = timestamps[mask]
    valid_values = values[mask]
    
    # 进行统计计算
    # 这里的计算可以非常复杂,例如 P99, std, etc.
    count = len(valid_values)
    total_sum = np.sum(valid_values)
    average = np.mean(valid_values) if count > 0 else 0.0
    p99 = np.percentile(valid_values, 99) if count > 0 else 0.0

    # 将更新后的数组转换回字节,准备存回Redis
    new_timestamps_bytes = valid_timestamps.tobytes()
    new_values_bytes = valid_values.tobytes()

    # 使用Base64编码以安全地通过stdout传输
    new_timestamps_b64 = base64.b64encode(new_timestamps_bytes).decode('utf-8')
    new_values_b64 = base64.b64encode(new_values_bytes).decode('utf-8')
    
    result = {
        "stats": {
            "count": int(count),
            "sum": float(total_sum),
            "average": float(average),
            "p99": float(p99)
        },
        "data": {
            "timestamps_b64": new_timestamps_b64,
            "values_b64": new_values_b64
        }
    }
    
    return json.dumps(result)

if __name__ == "__main__":
    # 从命令行参数读取输入
    # 格式: python counter_processor.py <ts_b64> <val_b64> <new_ts> <new_val> <window_sec>
    if len(sys.argv) != 6:
        sys.stderr.write("Usage: python counter_processor.py <ts_b64> <val_b64> <new_ts> <new_val> <window_sec>\n")
        sys.exit(1)

    _, ts_b64_arg, val_b64_arg, new_ts_arg, new_val_arg, window_sec_arg = sys.argv
    
    # 对于空的初始数据,传入特殊标记 "EMPTY"
    ts_b64_arg = "" if ts_b64_arg == "EMPTY" else ts_b64_arg
    val_b64_arg = "" if val_b64_arg == "EMPTY" else val_b64_arg

    output_json = process_window(
        ts_b64_arg,
        val_b64_arg,
        int(new_ts_arg),
        float(new_val_arg),
        int(window_sec_arg)
    )
    
    print(output_json)

Spring Boot 调用端实现

Java 这边需要负责调用这个 Python 脚本,并处理其输入输出。

// src/main/java/com/example/counter/SlidingWindowCounterV3.java
package com.example.counter;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.redisson.api.RLock;
import org.redisson.api.RedissonClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;

@Service
public class SlidingWindowCounterV3 {

    private static final Logger logger = LoggerFactory.getLogger(SlidingWindowCounterV3.class);
    private final RedissonClient redissonClient;
    private final ObjectMapper objectMapper;
    private static final String KEY_PREFIX = "sw_counter_bytes:";
    private static final String SCRIPT_PATH = "/scripts/counter_processor.py"; // 假设脚本在资源路径下

    public SlidingWindowCounterV3(RedissonClient redissonClient, ObjectMapper objectMapper) {
        this.redissonClient = redissonClient;
        this.objectMapper = objectMapper;
    }

    public Map<String, Double> addAndGetStats(String key, float value, int windowSizeInSeconds) {
        String redisKey = KEY_PREFIX + key;
        RLock lock = redissonClient.getLock("lock:" + redisKey);
        try {
            if (lock.tryLock(10, 60, TimeUnit.SECONDS)) {
                try {
                    // 1. 从Redis读取原始字节数据
                    Map<String, String> data = redissonClient.<String, String>getMap(redisKey).readAllMap();
                    String timestampsB64 = data.getOrDefault("timestamps", "EMPTY");
                    String valuesB64 = data.getOrDefault("values", "EMPTY");

                    // 2. 准备并执行Python脚本
                    ProcessBuilder pb = new ProcessBuilder(
                        "python3",
                        // 在生产环境中,你需要一个更鲁棒的方式来定位脚本
                        getClass().getResource(SCRIPT_PATH).getPath(),
                        timestampsB64,
                        valuesB64,
                        String.valueOf(System.currentTimeMillis()),
                        String.valueOf(value),
                        String.valueOf(windowSizeInSeconds)
                    );

                    Process process = pb.start();

                    // 3. 读取脚本的输出
                    StringBuilder output = new StringBuilder();
                    try (BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8))) {
                        String line;
                        while ((line = reader.readLine()) != null) {
                            output.append(line);
                        }
                    }

                    int exitCode = process.waitFor();
                    if (exitCode != 0) {
                        // 读取错误流
                        StringBuilder errorOutput = new StringBuilder();
                        try (BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream(), StandardCharsets.UTF_8))) {
                            String line;
                            while ((line = reader.readLine()) != null) {
                                errorOutput.append(line);
                            }
                        }
                        logger.error("Python script execution failed with exit code {}: {}", exitCode, errorOutput);
                        throw new RuntimeException("Python script execution failed.");
                    }

                    // 4. 解析结果并写回Redis
                    PythonResponse response = objectMapper.readValue(output.toString(), PythonResponse.class);
                    
                    Map<String, String> newData = new HashMap<>();
                    newData.put("timestamps", response.data.timestamps_b64);
                    newData.put("values", response.data.values_b64);
                    
                    redissonClient.getMap(redisKey).putAll(newData);
                    redissonClient.getMap(redisKey).expire(windowSizeInSeconds * 2, TimeUnit.SECONDS);

                    return response.stats;
                } finally {
                    if (lock.isHeldByCurrentThread()) {
                        lock.unlock();
                    }
                }
            } else {
                throw new RuntimeException("Could not acquire lock for key: " + key);
            }
        } catch (Exception e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Error processing sliding window", e);
        }
    }

    // 用于反序列化Python脚本输出的DTO
    private static class PythonResponse {
        public Map<String, Double> stats;
        public DataPayload data;
    }

    private static class DataPayload {
        public String timestamps_b64;
        public String values_b64;
    }
}

这个V3版本,虽然引入了跨语言调用的复杂性,但带来了巨大的收益:

  1. 极高的计算性能:NumPy 的向量化操作远胜于 Java 循环或 Redis Lua 脚本。对于包含数百万个数据点的窗口,性能差异可能是几个数量级。
  2. 强大的表达能力:我们不再局限于“计数”,而是可以轻松实现各种复杂的统计聚合,只需修改 Python 脚本即可,Java 端代码无需变动。
  3. 优化的网络IO和存储:原始字节的存储非常紧凑。我们一次性获取所有需要的数据,一次性写回,减少了与 Redis 的交互次数。
  4. 缩短锁持有时间:由于计算速度极快,整个事务(加锁->读Redis->计算->写Redis->解锁)的时间被显著缩短,大大提升了系统的并发能力。

局限性与未来优化路径

这个方案并非银弹。一个明显的代价是运维复杂性的增加,需要管理 Python 环境和依赖。Java 与 Python 进程间的通信开销也是一个不可忽视的性能损耗点。

未来的优化路径可以集中在以下几个方面:

  1. 通信优化:使用更高效的IPC机制,如gRPC或Unix Domain Sockets,替代基于进程启动和标准输入输出的简单模式,以降低通信延迟。
  2. 消除进程调用:探索在JVM中直接运行Python代码的方案,例如使用 GraalVM 的 Polyglot 能力。这能彻底消除跨进程开销,但会增加构建和部署的复杂性。
  3. 向Redis内部推进计算:最极致的优化是利用 RedisGears 或 Redis Functions (Redis 7.0+),将带有 NumPy 依赖的 Python 函数直接部署到 Redis 服务端执行。这样,数据无需离开 Redis 实例,网络开销降为零,性能将达到理论上的最优。但这需要 Redis 版本的支持以及对 Redis 模块开发的深入理解。

最终,这个V3架构是一个在现有普遍技术栈(Spring Boot, Python, Redis)下,通过巧妙的数据结构设计和计算任务卸载,实现高性能分布式计算的务实选择。它展示了在解决复杂工程问题时,跨技术栈组合有时能创造出远超单一技术栈能力的解决方案。


  目录