@@ -0,0 +1,98 @@ | |||||
package com.xkcoding.ratelimit.redis.aspect; | |||||
import cn.hutool.core.util.StrUtil; | |||||
import com.xkcoding.ratelimit.redis.annotation.RateLimiter; | |||||
import com.xkcoding.ratelimit.redis.util.IpUtil; | |||||
import lombok.RequiredArgsConstructor; | |||||
import lombok.extern.slf4j.Slf4j; | |||||
import org.aspectj.lang.ProceedingJoinPoint; | |||||
import org.aspectj.lang.annotation.Around; | |||||
import org.aspectj.lang.annotation.Aspect; | |||||
import org.aspectj.lang.annotation.Pointcut; | |||||
import org.aspectj.lang.reflect.MethodSignature; | |||||
import org.springframework.beans.factory.annotation.Autowired; | |||||
import org.springframework.core.annotation.AnnotationUtils; | |||||
import org.springframework.data.redis.core.StringRedisTemplate; | |||||
import org.springframework.data.redis.core.script.RedisScript; | |||||
import org.springframework.stereotype.Component; | |||||
import java.lang.reflect.Method; | |||||
import java.time.Instant; | |||||
import java.util.Collections; | |||||
import java.util.concurrent.TimeUnit; | |||||
/** | |||||
* <p> | |||||
* 限流切面 | |||||
* </p> | |||||
* | |||||
* @author yangkai.shen | |||||
* @date Created in 2019/9/30 10:30 | |||||
*/ | |||||
@Slf4j | |||||
@Aspect | |||||
@Component | |||||
@RequiredArgsConstructor(onConstructor_ = @Autowired) | |||||
public class RateLimiterAspect { | |||||
private final static String SEPARATOR = ":"; | |||||
private final static String REDIS_LIMIT_KEY_PREFIX = "limit:"; | |||||
private final StringRedisTemplate stringRedisTemplate; | |||||
private final RedisScript<Long> limitRedisScript; | |||||
@Pointcut("@annotation(com.xkcoding.ratelimit.redis.annotation.RateLimiter)") | |||||
public void rateLimit() { | |||||
} | |||||
@Around("rateLimit()") | |||||
public Object pointcut(ProceedingJoinPoint point) throws Throwable { | |||||
MethodSignature signature = (MethodSignature) point.getSignature(); | |||||
Method method = signature.getMethod(); | |||||
// 通过 AnnotationUtils.findAnnotation 获取 RateLimiter 注解 | |||||
RateLimiter rateLimiter = AnnotationUtils.findAnnotation(method, RateLimiter.class); | |||||
if (rateLimiter != null) { | |||||
String key = rateLimiter.key(); | |||||
// 默认用方法名做限流的 key 前缀 | |||||
if (StrUtil.isBlank(key)) { | |||||
key = method.getName(); | |||||
} | |||||
// 最终限流的 key 为 前缀 + IP地址 | |||||
// TODO: 此时需要考虑局域网多用户访问的情况,因此 key 后续需要加上方法参数更加合理 | |||||
key = key + SEPARATOR + IpUtil.getIpAddr(); | |||||
long max = rateLimiter.max(); | |||||
long timeout = rateLimiter.timeout(); | |||||
TimeUnit timeUnit = rateLimiter.timeUnit(); | |||||
boolean limited = shouldLimited(key, max, timeout, timeUnit); | |||||
if (limited) { | |||||
throw new RuntimeException("手速太快了,慢点儿吧~"); | |||||
} | |||||
} | |||||
return point.proceed(); | |||||
} | |||||
private boolean shouldLimited(String key, long max, long timeout, TimeUnit timeUnit) { | |||||
// 最终的 key 格式为: | |||||
// limit:自定义key:IP | |||||
// limit:方法名:IP | |||||
key = REDIS_LIMIT_KEY_PREFIX + key; | |||||
// 统一使用单位毫秒 | |||||
long ttl = timeUnit.toMillis(timeout); | |||||
// 当前时间毫秒数 | |||||
long now = Instant.now().toEpochMilli(); | |||||
long expired = now - ttl; | |||||
// 注意这里必须转为 String,否则会报错 java.lang.Long cannot be cast to java.lang.String | |||||
Long executeTimes = stringRedisTemplate.execute(limitRedisScript, Collections.singletonList(key), now + "", ttl + "", expired + "", max + ""); | |||||
if (executeTimes != null) { | |||||
if (executeTimes == 0) { | |||||
log.error("【{}】在单位时间 {} 毫秒内已达到访问上限,当前接口上限 {}", key, ttl, max); | |||||
return true; | |||||
} else { | |||||
log.info("【{}】在单位时间 {} 毫秒内访问 {} 次", key, ttl, executeTimes); | |||||
return false; | |||||
} | |||||
} | |||||
return false; | |||||
} | |||||
} |
@@ -0,0 +1,28 @@ | |||||
package com.xkcoding.ratelimit.redis.config; | |||||
import org.springframework.context.annotation.Bean; | |||||
import org.springframework.context.annotation.Configuration; | |||||
import org.springframework.core.io.ClassPathResource; | |||||
import org.springframework.data.redis.core.script.DefaultRedisScript; | |||||
import org.springframework.data.redis.core.script.RedisScript; | |||||
import org.springframework.scripting.support.ResourceScriptSource; | |||||
/** | |||||
* <p> | |||||
* Redis 配置 | |||||
* </p> | |||||
* | |||||
* @author yangkai.shen | |||||
* @date Created in 2019/9/30 11:37 | |||||
*/ | |||||
@Configuration | |||||
public class RedisConfig { | |||||
@Bean | |||||
@SuppressWarnings("unchecked") | |||||
public RedisScript<Long> limitRedisScript() { | |||||
DefaultRedisScript redisScript = new DefaultRedisScript<>(); | |||||
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("scripts/redis/limit.lua"))); | |||||
redisScript.setResultType(Long.class); | |||||
return redisScript; | |||||
} | |||||
} |
@@ -0,0 +1,59 @@ | |||||
package com.xkcoding.ratelimit.redis.util; | |||||
import cn.hutool.core.util.StrUtil; | |||||
import lombok.extern.slf4j.Slf4j; | |||||
import org.springframework.web.context.request.RequestContextHolder; | |||||
import org.springframework.web.context.request.ServletRequestAttributes; | |||||
import javax.servlet.http.HttpServletRequest; | |||||
/** | |||||
* <p> | |||||
* IP 工具类 | |||||
* </p> | |||||
* | |||||
* @author yangkai.shen | |||||
* @date Created in 2019/9/30 10:38 | |||||
*/ | |||||
@Slf4j | |||||
public class IpUtil { | |||||
private final static String UNKNOWN = "unknown"; | |||||
private final static int MAX_LENGTH = 15; | |||||
/** | |||||
* 获取IP地址 | |||||
* 使用Nginx等反向代理软件, 则不能通过request.getRemoteAddr()获取IP地址 | |||||
* 如果使用了多级反向代理的话,X-Forwarded-For的值并不止一个,而是一串IP地址,X-Forwarded-For中第一个非unknown的有效IP字符串,则为真实IP地址 | |||||
*/ | |||||
public static String getIpAddr() { | |||||
HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest(); | |||||
String ip = null; | |||||
try { | |||||
ip = request.getHeader("x-forwarded-for"); | |||||
if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) { | |||||
ip = request.getHeader("Proxy-Client-IP"); | |||||
} | |||||
if (StrUtil.isEmpty(ip) || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) { | |||||
ip = request.getHeader("WL-Proxy-Client-IP"); | |||||
} | |||||
if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) { | |||||
ip = request.getHeader("HTTP_CLIENT_IP"); | |||||
} | |||||
if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) { | |||||
ip = request.getHeader("HTTP_X_FORWARDED_FOR"); | |||||
} | |||||
if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) { | |||||
ip = request.getRemoteAddr(); | |||||
} | |||||
} catch (Exception e) { | |||||
log.error("IPUtils ERROR ", e); | |||||
} | |||||
// 使用代理,则获取第一个IP地址 | |||||
if (!StrUtil.isEmpty(ip) && ip.length() > MAX_LENGTH) { | |||||
if (ip.indexOf(StrUtil.COMMA) > 0) { | |||||
ip = ip.substring(0, ip.indexOf(StrUtil.COMMA)); | |||||
} | |||||
} | |||||
return ip; | |||||
} | |||||
} |
@@ -2,3 +2,20 @@ server: | |||||
port: 8080 | port: 8080 | ||||
servlet: | servlet: | ||||
context-path: /demo | context-path: /demo | ||||
spring: | |||||
redis: | |||||
host: localhost | |||||
# 连接超时时间(记得添加单位,Duration) | |||||
timeout: 10000ms | |||||
# Redis默认情况下有16个分片,这里配置具体使用的分片 | |||||
# database: 0 | |||||
lettuce: | |||||
pool: | |||||
# 连接池最大连接数(使用负值表示没有限制) 默认 8 | |||||
max-active: 8 | |||||
# 连接池最大阻塞等待时间(使用负值表示没有限制) 默认 -1 | |||||
max-wait: -1ms | |||||
# 连接池中的最大空闲连接 默认 8 | |||||
max-idle: 8 | |||||
# 连接池中的最小空闲连接 默认 0 | |||||
min-idle: 0 |
@@ -0,0 +1,27 @@ | |||||
-- 下标从 1 开始 | |||||
local key = KEYS[1] | |||||
local now = tonumber(ARGV[1]) | |||||
local ttl = tonumber(ARGV[2]) | |||||
local expired = tonumber(ARGV[3]) | |||||
-- 最大访问量 | |||||
local max = tonumber(ARGV[4]) | |||||
-- 清除过期的数据 | |||||
-- 移除指定分数区间内的所有元素,expired 即已经过期的 score | |||||
-- 根据当前时间毫秒数 - 超时毫秒数,得到过期时间 expired | |||||
redis.call('zremrangebyscore', key, 0, expired) | |||||
-- 获取 zset 中的元素个数 | |||||
local current = tonumber(redis.call('zcard', key)) | |||||
-- | |||||
local next = current + 1 | |||||
if next > max then | |||||
-- 达到限流大小 返回 0 | |||||
return 0; | |||||
else | |||||
-- 往 zset 中添加一个值、得分均为当前时间戳的元素,[value,score] | |||||
redis.call("zadd", key, now, now) | |||||
-- 每次访问均重新设置 zset 的过期时间,单位毫秒 | |||||
redis.call("pexpire", key, ttl) | |||||
return next | |||||
end |