背景

记一次惨痛的教训:

一次节日临近,产品经理兴致勃勃地组织团队研发抢红包活动。小目标定下来之后,研发团队加班赶进度,终于软件成功上线。不幸的是,后来发现有人疯狂刷接口,一笔红包提现N次,累计被盗走好几万资金。

参与研发的同事被老板叫到办公室骂了个狗血喷头,还被罚了款 。

故事结尾,CTO让我来解决这个并发安全问题。故事从这里开始…

什么是分布式锁?

线程锁和进程锁仅能满足在单机jvm或者同一个操作系统下,才能有效,跨jvm无法满足。因此就产生了分布式锁,其目的可以概括为:控制多个分布式节点的线程对同一项资源的互斥访问。

一个完善的分布式锁需要实现哪些目标?

  • 分布式环境下的互斥功能
  • 锁在事务的外围
  • 防止死锁的发生
  • 锁等待
  • 锁重入
  • 锁续约
  • 使用方便

手写分布式锁——基础版

该版本的加锁和业务逻辑代码混在一起,按照加锁->业务逻辑->解锁的顺序执行。
基于lua脚本和setnx命令实现,保证加锁过程的原子性。
核心代码:

1
2
3
4
5
6
lock(key, value, seconds);
try {
//do business
} finally {
unlock(key, value);
}

加锁:

1
2
3
4
5
6
7
8
9
10
11
12
boolean lock(String key, String value, long seconds){
String script = "if redis.call('SETNX', KEYS[1], ARGV[1]) == 1" +
" then if redis.call('EXPIRE', KEYS[1], ARGV[2]) == 1" +
" then return 1" +
" else redis.call('DEL', KEYS[1])" +
" end" +
" end" +
" return 0";
DefaultRedisScript<Long> redisScript = new DefaultRedisScript<Long>(script, Long.class);
Long result = stringRedisTemplate.execute(redisScript, Arrays.asList(key), value, String.valueOf(seconds));
return result != null && result == 1;
}

解锁:

1
2
3
4
5
6
7
8
9
10
11
boolean unlock(String key, String value) {
String script = "if redis.call('get', KEYS[1]) == ARGV[1]" +
" then" +
" return redis.call('del', KEYS[1])" +
" else" +
" return 0" +
" end";
DefaultRedisScript<Long> redisScript = new DefaultRedisScript<Long>(script, Long.class);
Long result = stringRedisTemplate.execute(redisScript, Arrays.asList(key), value);
return result != null && result == 1;
}

但是基础版实现分布式锁,无法实现锁等待、锁重入、锁续约,且使用也不方便,还是和业务逻辑耦合。

手写分布式锁——注解版

该版本通过注解实现。
加锁、解锁底层访问redis的方法和上一个版本相同。业务加锁一行注解代码轻松搞定,调用加锁、解锁方法的代码搬到切面中。为了支持所有用户线程都有机会拿到锁,补充了大量排队等待的逻辑。
使用方式:

1
2
3
4
5
@DistributedLock(keyPrefix = "myTestBusiness", keyField = "dto.id")
@Transactional(rollbackFor = Exception.class)
public void myTest(MyTestDTO dto) {
//do business
}

注解类:

1
2
3
4
5
6
7
8
9
10
11
12
@Retention(RetentionPolicy.RUNTIME)
@Target(value = {ElementType.METHOD})
public @interface DistributedLock {
/** 锁前缀。不同业务不能相同 */
String keyPrefix();
/** 关键字段。方法的参数或者参数的字段。举例: 1、id 2、user.id */
String keyField();
/** 锁的过期时长,单位:毫秒*/
long expireMillis() default 10000L;
/** 获取锁的等待超时时间,单位:毫秒。默认值:0不开启*/
long waitTimeoutMillis() default 0L;
}

分布式锁的切面:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
@Aspect
@Order(Ordered.HIGHEST_PRECEDENCE)//保证该切面在事务切面的外围,释放锁必须在事务提交之后
@Component
public class DistributedLockAspect {
@Around(“@annotation(com.jrit.zlzp.framework.concurrent.DistributedLock)”)
public Object around(JoinPoint joinPoint) throws Throwable {
DistributedLock lock = ;//拿到注解对象
Worker worker = new Worker(key, lock.expireMillis(), lock.waitTimeoutMillis()); //构建任务
//加锁
if(!lock(worker.getKey(), worker.getValue(), worker.getExpireMillis()/1000)) {
//加锁失败
workQueueMap.get(key).offer(worker); //入队列排队
LockSupport.parkNanos(worker.waitTimeoutMillis*1000000); //阻塞线程
}
//处理业务
Object object = null;
try {
object = ((ProceedingJoinPoint)joinPoint).proceed();
} catch (Throwable throwable) {
throw throwable;
} finally {
unlock(worker.getKey(), worker.getValue()); //解锁
worker.awakeNext(key); //唤醒下一任务
}
return object;
}

该版本为了提升用户体验,保证每个用户线程都有机会拿到锁,采取了2个办法。
办法一:唤醒线程。当前占有锁的线程释放锁之后调用唤醒线程方法,令队列头部的等待线程尝试加锁。

1
2
3
4
5
6
7
8
9
void awakeNext(String key) {
LinkedBlockingQueue<Worker> queue = workQueueMap.get(key);
Worker worker = queue.peek();//取队头元素
//获取分布式锁
if(lock(worker.getKey(), worker.getValue(), worker.getExpireMillis()/1000)) {
queue.poll();
LockSupport.unpark(worker.thread);//唤醒业务线程
}
}

办法二:监视线程。有一个守护线程,用自旋的方式对全部队列进行扫描 ,令每个队列头部的等待线程尝试加锁。办法二是对一的补充,使得容错性更佳。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void watch() {
while(true) {
//对每一个key,从对应的等待队列中取出队头元素,尝试获取锁
workQueueMap.keySet().parallelStream().forEach(key -> {
LinkedBlockingQueue<Worker> queue = workQueueMap.get(key);
Worker worker = queue.peek();//取队头元素
//获取分布式锁
if(lock(worker.getKey(), worker.getValue(), worker.getExpireMillis()/1000)) {
queue.poll();//队头元素出队
LockSupport.unpark(worker.thread);//唤醒业务线程
}
});
}
}

注解版除了不能实现锁续约以外,其他都实现了。

手写分布式锁——注解PLUS版

该版本在注解实现的基础上,引入redisson,分布式锁机制更加完善和简洁易用。打开了新世界的大门!
注解类去掉过期时间字段,上一个版本中大量的排队、监视线程、唤醒线程以及lua脚本加锁的代码全部干掉,把注解切面中加锁、解锁的代码替换为redisson的锁方法,即完成了完美的蜕变。
注解的使用方式不变:

1
2
3
4
5
@DistributedLock(keyPrefix = "myTestBusiness", keyField = "dto.id")
@Transactional(rollbackFor = Exception.class)
public void myTest(MyTestDTO dto) {
//do business
}

分布式锁的切面:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
@Aspect
@Order(Ordered.HIGHEST_PRECEDENCE)//保证该切面在事务切面的外围,解锁必须在事务提交之后
@Component
@Slf4j
public class DistributedLockAspect {

@Around("@annotation(com.jrit.zlzp.framework.concurrent.DistributedLock)")
public Object around(JoinPoint joinPoint) throws Throwable {
//获取分布式key
Signature signature = joinPoint.getSignature();
DistributedLock lock = ((MethodSignature) signature).getMethod().getAnnotation(DistributedLock.class);
Asserts.notEmpty(lock.keyPrefix(), "keyPrefix");
Asserts.notEmpty(lock.keyField(), "keyField");
String[] keyFields = lock.keyField().split("\\.");
if (keyFields.length > 2) {
throw new CustomException("为了保证性能,keyField的层级不要超过2");
}
//小提示,帮助使用者快速上手
if (lock.waitTimeoutMillis() < 0) {
throw new CustomException("锁的等待超时时间非法");
}

String[] paramNames = ((CodeSignature) signature).getParameterNames();
Object[] params = joinPoint.getArgs();
String value = null;
for (int i = 0; i < paramNames.length; i++) {
if (keyFields[0].equals(paramNames[i])) {
if (keyFields.length == 1) {
value = String.valueOf(params[i]);
} else {
Field field = params[i].getClass().getDeclaredField(keyFields[1]);
field.setAccessible(true);
value = String.valueOf(field.get(params[i]));
}
break;
}
}
if (!StringUtils.hasText(value) || "null".equals(value)) {
throw new CustomException(lock.keyField() + "不能为空");
}
String key = lock.keyPrefix() + ":" + value;
//加锁
if (!RedisUtil.lock(key, lock.waitTimeoutMillis() / 1000)) {
throw new CustomException("系统繁忙,请稍后再试");
}
//处理业务
Object object = null;
try {
object = ((ProceedingJoinPoint) joinPoint).proceed();
} catch (Throwable throwable) {
throw throwable;
} finally {
//解锁
try {
RedisUtil.unlock(key);
} catch (Exception e) {//吞噬异常。切面在事务外围,抛错也不会回滚事务
log.error("DistributedLockAspect.around unlock error===key:{}", key, e);
}
}
return object;
}

}

加锁方法,leaseTime=-1启用看门狗机制,锁会自动续期。

看门狗后台线程定期检查并续期,会检查锁是否被当前线程所持有,如果仍然持有,就把锁的过期时间重置为30s(看门狗默认超时时间),续期操作也是通过 Lua 脚本原子性实现,每10s会执行一次看门狗。因此只要业务还在执行,锁就不会过期。

线程1持有锁,如果Redisson客户端此时宕机,看门狗线程会停止,无法续期,30s后锁自动过期释放,其他客户端也可以获取锁,这样避免了死锁问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
  private static RedisTemplate<String, Object> redisTemplate = null;

private static StringRedisTemplate stringRedisTemplate =
(StringRedisTemplate)SpringContextHolder.getApplicationContext().getBean("stringRedisTemplate");

private static RedissonClient redissonClient =
(RedissonClient)SpringContextHolder.getApplicationContext().getBean("redissonClient");

final static Integer COMMON_DEFAULT_TIME = 30;

// 获取连接
@SuppressWarnings("unchecked")
public static RedisTemplate<String, Object> getRedis() {
if (redisTemplate == null) {
synchronized (RedisUtil.class) {
if (redisTemplate == null) {
ApplicationContext wac = SpringContextHolder.getApplicationContext();
redisTemplate = (RedisTemplate<String, Object>) wac.getBean("redisTemplate");
}
}
}
return redisTemplate;
}

/**
* 加锁
* <br>支持分布式锁,建议与{@link #unlock(String key)}配合使用
* @param key
* @param waitSeconds 锁等待超时秒数
*/
public static boolean lock(String key, long waitSeconds){
RLock lock = redissonClient.getLock(key);
try {
// leaseTime=-1启用看门狗
boolean acquired = lock.tryLock(waitSeconds, -1, TimeUnit.SECONDS);
if (acquired) {
log.debug("获取分布式锁成功(看门狗机制), key: {}", key);
} else {
log.warn("获取分布式锁失败(超时), key: {}, waitSeconds: {}", key, waitSeconds);
}
return acquired;
} catch (Exception e) {
log.error("RedisUtil.lock error. key:{}", key, e);
}
return false;
}



解锁方法,注意需判断是否被当前线程所持有。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
/**
* 安全解锁
* @param key 锁的key
* @return 是否解锁成功
*/
public static boolean unlock(String key) {
return unlock(redissonClient.getLock(key));
}

/**
* 安全解锁 - 使用RLock实例
* @param lock RLock实例
* @return 是否解锁成功
*/
public static boolean unlock(RLock lock) {
if (lock == null) {
log.warn("解锁失败: lock为null");
return false;
}

try {
// 检查锁状态
if (!lock.isLocked()) {
log.debug("锁未被持有, 无需解锁: {}", lock.getName());
return true;
}

if (lock.isHeldByCurrentThread()) {
lock.unlock();
log.debug("解锁成功: {}", lock.getName());
return true;
} else {
log.warn("解锁失败: 当前线程不是锁的持有者, lock: {}", lock.getName());
return false;
}

} catch (IllegalMonitorStateException e) {
log.warn("解锁失败: 非法监控状态, lock: {}", lock.getName(), e);
return false;
} catch (Exception e) {
log.error("解锁异常, lock: {}", lock.getName(), e);
return false;
}
}

下面分享一个预防被盗刷的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
@DistributedLock(keyPrefix = "BANK_ACCOUNT_BIND_LOCK", keyField = "dto.userId", waitTimeoutMillis = 3000L)
@Transactional(rollbackFor = Exception.class)
public Map<String, String> bind(SysUserBankAccountDTO dto) {
AssertJrit.isTrue(StringUtils.isNotEmpty(dto.getName()) && StringUtils.isNotEmpty(dto.getIdCard())
&& StringUtils.isNotEmpty(dto.getAccountNo()) && StringUtils.isNotEmpty(dto.getPhone())
, "参数错误");
Map<String, String> result = new HashMap<>();
Long userId = dto.getUserId();

/****************************限流策略 start*******************************/
//每日总次数限制,防止付费接口被盗刷,造成资金损失
String key = String.format("BANK_ACCOUNT_BIND_NUM_TOTAL:%s"
, DateUtils.parseDateToStr(DateUtils.YYYYMMDD, new Date()));
Object object = null;
Integer bindCount = (object = RedisUtil.get(key)) != null ? (int) object : 0;
//更新用户总次数
RedisUtil.set(key, bindCount + 1, 120 * 3600);
if (bindCount >= 300) {
result.put("errorMsg", "绑定失败,请联系平台客服");
return result;
}

//用户每日次数限制,防止付费接口被盗刷,造成资金损失
key = String.format("BANK_ACCOUNT_BIND_NUM_USER:%s:%s"
, DateUtils.parseDateToStr(DateUtils.YYYYMMDD, new Date()), userId);
object = null;
bindCount = (object = RedisUtil.get(key)) != null ? (int) object : 0;
//更新用户每日次数
RedisUtil.set(key, bindCount + 1, 120 * 3600);
int bindCountLimit = 5;
if (bindCount >= bindCountLimit) {
result.put("errorMsg", String.format("绑定失败。每天最多可绑定%s次,当前剩余%s次", bindCountLimit, 0));
return result;
}
int bindCountRemain = bindCountLimit - bindCount - 1;
String errorMsg4NumLimit = String.format("。每天最多可绑定%s次,当前剩余%s次", bindCountLimit, bindCountRemain);
/****************************限流策略 end*******************************/

//银行卡基础信息查询(付费接口)
JSONObject jsonObject = paasService.bankCardInformation(dto.getAccountNo());
if (jsonObject == null || jsonObject.getInteger("code") != 1) {
result.put("errorMsg", "绑定失败,请核对信息后重试" + errorMsg4NumLimit);
return result;
} else if (jsonObject.getJSONObject("data").getInteger("accountType") == 2) {
result.put("errorMsg", "暂不支持绑定信用卡" + errorMsg4NumLimit);
return result;
}
//银行卡四要素核验(付费接口)
boolean isSuccess = paasService.bankCard4EVerification(
dto.getName(), dto.getIdCard(), dto.getAccountNo(), dto.getPhone());
if (!isSuccess) {
result.put("errorMsg", "绑定失败,请核对信息后重试" + errorMsg4NumLimit);
return result;
}
//查询旧数据
SysUserBankAccount accountOld = sysUserBankAccountService.getByUserId(userId);
if (accountOld != null) {
//删除旧数据
SysUserBankAccount account = SysUserBankAccount.builder()
.id(accountOld.getId()).isDeleted(System.currentTimeMillis())//兼容唯一索引
.build();
UserCacheUtils.setUpdateParam(account);
sysUserBankAccountService.updateById(account);
}
//新增
SysUserBankAccount account = BeanUtil.toBean(dto, SysUserBankAccount.class);
account.setId(null);
account.setUserId(userId);
account.setBankName(jsonObject.getJSONObject("data").getString("accountBank"));
UserCacheUtils.setCommonParam(account);
sysUserBankAccountService.save(account);

SysUser user = sysUserService.getByUserId(userId);

//签约小创云
//特殊处理:因为支付牌照申请不易,所以非生产环境的微工卡核身流程可能不可用,在绑定银行卡时签约小创云
if (!envConfig.isPrd()) {
if (!xcyAtomService.signContact(RealNameDTO.builder().name(user.getUserName())
.phone(user.getUserPhone()).idCard(user.getIdCard())
.idCardFrontUrl(user.getIdCardFrontUrl()).idCardBackUrl(user.getIdCardBackUrl()).build())) {
throw new CustomException("签约失败" + errorMsg4NumLimit);
}
}

return null;//兼容前端
}