# spring-boot-demo-ratelimit-redis
> 此 demo 主要演示了 Spring Boot 项目如何通过 AOP 结合 Redis + Lua 脚本实现分布式限流,旨在保护 API 被恶意频繁访问的问题,是 `spring-boot-demo-ratelimit-guava` 的升级版。
## 1. 主要代码
### 1.1. pom.xml
```xml
4.0.0
spring-boot-demo-ratelimit-redis
1.0.0-SNAPSHOT
jar
spring-boot-demo-ratelimit-redis
Demo project for Spring Boot
com.xkcoding
spring-boot-demo
1.0.0-SNAPSHOT
UTF-8
UTF-8
1.8
org.springframework.boot
spring-boot-starter-web
org.springframework.boot
spring-boot-starter-aop
org.springframework.boot
spring-boot-starter-data-redis
org.apache.commons
commons-pool2
cn.hutool
hutool-all
org.springframework.boot
spring-boot-starter-test
test
org.projectlombok
lombok
true
spring-boot-demo-ratelimit-redis
org.springframework.boot
spring-boot-maven-plugin
```
### 1.2. 限流注解
```java
/**
*
* 限流注解,添加了 {@link AliasFor} 必须通过 {@link AnnotationUtils} 获取,才会生效
*
*
* @author yangkai.shen
* @date Created in 2019-09-30 10:31
* @see AnnotationUtils
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter {
long DEFAULT_REQUEST = 10;
/**
* max 最大请求数
*/
@AliasFor("max") long value() default DEFAULT_REQUEST;
/**
* max 最大请求数
*/
@AliasFor("value") long max() default DEFAULT_REQUEST;
/**
* 限流key
*/
String key() default "";
/**
* 超时时长,默认1分钟
*/
long timeout() default 1;
/**
* 超时时间单位,默认 分钟
*/
TimeUnit timeUnit() default TimeUnit.MINUTES;
}
```
### 1.3. AOP处理限流
```java
/**
*
* 限流切面
*
*
* @author yangkai.shen
* @date Created in 2019-09-30 10:30
*/
@Slf4j
@Aspect
@Component
@RequiredArgsConstructor(onConstructor_ = @Autowired)
public class RateLimiterAspect {
private final static String SEPARATOR = ":";
private final static String REDIS_LIMIT_KEY_PREFIX = "limit:";
private final StringRedisTemplate stringRedisTemplate;
private final RedisScript limitRedisScript;
@Pointcut("@annotation(com.xkcoding.ratelimit.redis.annotation.RateLimiter)")
public void rateLimit() {
}
@Around("rateLimit()")
public Object pointcut(ProceedingJoinPoint point) throws Throwable {
MethodSignature signature = (MethodSignature) point.getSignature();
Method method = signature.getMethod();
// 通过 AnnotationUtils.findAnnotation 获取 RateLimiter 注解
RateLimiter rateLimiter = AnnotationUtils.findAnnotation(method, RateLimiter.class);
if (rateLimiter != null) {
String key = rateLimiter.key();
// 默认用类名+方法名做限流的 key 前缀
if (StrUtil.isBlank(key)) {
key = method.getDeclaringClass().getName()+StrUtil.DOT+method.getName();
}
// 最终限流的 key 为 前缀 + IP地址
// TODO: 此时需要考虑局域网多用户访问的情况,因此 key 后续需要加上方法参数更加合理
key = key + SEPARATOR + IpUtil.getIpAddr();
long max = rateLimiter.max();
long timeout = rateLimiter.timeout();
TimeUnit timeUnit = rateLimiter.timeUnit();
boolean limited = shouldLimited(key, max, timeout, timeUnit);
if (limited) {
throw new RuntimeException("手速太快了,慢点儿吧~");
}
}
return point.proceed();
}
private boolean shouldLimited(String key, long max, long timeout, TimeUnit timeUnit) {
// 最终的 key 格式为:
// limit:自定义key:IP
// limit:类名.方法名:IP
key = REDIS_LIMIT_KEY_PREFIX + key;
// 统一使用单位毫秒
long ttl = timeUnit.toMillis(timeout);
// 当前时间毫秒数
long now = Instant.now().toEpochMilli();
long expired = now - ttl;
// 注意这里必须转为 String,否则会报错 java.lang.Long cannot be cast to java.lang.String
Long executeTimes = stringRedisTemplate.execute(limitRedisScript, Collections.singletonList(key), now + "", ttl + "", expired + "", max + "");
if (executeTimes != null) {
if (executeTimes == 0) {
log.error("【{}】在单位时间 {} 毫秒内已达到访问上限,当前接口上限 {}", key, ttl, max);
return true;
} else {
log.info("【{}】在单位时间 {} 毫秒内访问 {} 次", key, ttl, executeTimes);
return false;
}
}
return false;
}
}
```
### 1.4. lua 脚本
```lua
-- 下标从 1 开始
local key = KEYS[1]
local now = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local expired = tonumber(ARGV[3])
-- 最大访问量
local max = tonumber(ARGV[4])
-- 清除过期的数据
-- 移除指定分数区间内的所有元素,expired 即已经过期的 score
-- 根据当前时间毫秒数 - 超时毫秒数,得到过期时间 expired
redis.call('zremrangebyscore', key, 0, expired)
-- 获取 zset 中的当前元素个数
local current = tonumber(redis.call('zcard', key))
local next = current + 1
if next > max then
-- 达到限流大小 返回 0
return 0;
else
-- 往 zset 中添加一个值、得分均为当前时间戳的元素,[value,score]
redis.call("zadd", key, now, now)
-- 每次访问均重新设置 zset 的过期时间,单位毫秒
redis.call("pexpire", key, ttl)
return next
end
```
### 1.5. 接口测试
```java
/**
*
* 测试
*
*
* @author yangkai.shen
* @date Created in 2019-09-30 10:30
*/
@Slf4j
@RestController
public class TestController {
@RateLimiter(value = 5)
@GetMapping("/test1")
public Dict test1() {
log.info("【test1】被执行了。。。。。");
return Dict.create().set("msg", "hello,world!").set("description", "别想一直看到我,不信你快速刷新看看~");
}
@GetMapping("/test2")
public Dict test2() {
log.info("【test2】被执行了。。。。。");
return Dict.create().set("msg", "hello,world!").set("description", "我一直都在,卟离卟弃");
}
@RateLimiter(value = 2, key = "测试自定义key")
@GetMapping("/test3")
public Dict test3() {
log.info("【test3】被执行了。。。。。");
return Dict.create().set("msg", "hello,world!").set("description", "别想一直看到我,不信你快速刷新看看~");
}
}
```
### 1.6. 其余代码参见 demo
## 2. 测试
- 触发限流时控制台打印

- 触发限流的时候 Redis 的数据

## 3. 参考
- [mica-plus-redis 的分布式限流实现](https://github.com/lets-mica/mica/tree/master/mica-plus-redis)
- [Java并发:分布式应用限流 Redis + Lua 实践](https://segmentfault.com/a/1190000016042927)