分布式缓存

对于经常使用的数据,我们一般会使用 Redis 作为缓存机制,为了实现高可用,使用了3台Redis(没有设置集群,集群至少要6台)。

使用hash算法,存储的时候根据公式 h = hash(key)%机器节点数,h 为 Redis 对应的编号,取数据的时候也根据相同的公式取,因此一定可以从存储的机器中拿到想要的数据。但是使用这种策略可能会存在以下问题:

  • 假设有一台 Redis 服务器宕机了,此时每个 key 就要按照 h = hash(key)%(机器节点数-1) 重新计算
  • 假设要新增一台 Redis 服务器,此时每个 key 就要按照 h = hash(key)%(机器节点数+1) 重新计算

也就是说,如果服务节点有变更,会导致缓存失效,大量的 key 需要重新计算,在这期间如果有请求进来,就会直接打到数据库上,导致缓存雪崩。

一致性哈希算法

一致性哈希是讲整个哈希空间组织成一个虚拟的圆环,假设哈希函数 H 的值空间为 [0,2^32-1](哈希值是32位无符号整形)。

把服务器按照 IP 或者主机名作为关键字进行哈希,确定服务器在哈希环中的位置。

再使用哈希函数把数据对象映射到环上,数据从顺时针方向找,遇到的第一个服务器就是它定位到的服务器。

image-20231228103401543

结论:数据1、2存储服务器B上,数据3存储在服务器C上,数据4存储在服务器A上

容错性和可扩展性

假如这时候有服务器C宕机了呢?那么只有原本在B和C之间的数据会失效,重新定位到服务器A,其他数据节点的服务器不会发生变化。

image-20231228103647790

或者我们想新增一台服务器D呢?那么只有C和D之间的数据会失效,重新定位到服务器D,而其他的数据节点的存储服务器也不会发生任何变化。

image-20231228103831559

可以看出,一致性哈希算法对于节点的增减只会有一部分数据需要重新定位,不会导致大量的缓存失效。

虚拟节点

现实的业务场景中,节点不会分布得那么均匀,如果节点较少,可能会出现数据倾斜的情况。

观察下图,所有的数据全都定位到服务B上,无法实现负载均衡了。

image-20231228104239560

为了解决这种数据存储不平衡的问题,一致性哈希算法引入了虚拟节点机制,即对每个节点计算多个哈希值,每个计算结果位置都放置在对应节点中,这些节点称为虚拟节点。

image-20231228112641899

增加了虚拟节点到实际节点的映射,这样就能解决服务节点少时数据不平均的问题了。在实际应用中,通常将虚拟节点数设置为32甚至更大,因此即使很少的服务节点也能做到相对均匀的数据分布。

手撕源码

介绍完一致性哈希算法的概念和规则,接下来我们从源码的角度分析一致性哈希算法是怎么实现的。

哈希算法

首先确定项目中要使用的哈希算法,其中服务器和数据的映射都依赖哈希算法。

非加密算法:MurMurHash算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
/**
* MurMurHash算法,是非加密HASH算法,性能很高,
* 比传统的CRC32,MD5,SHA-1(这两个算法都是加密HASH算法,复杂度本身就很高,带来的性能上的损害也不可避免)
* 等HASH算法要快很多,而且据说这个算法的碰撞率很低.
* http://murmurhash.googlepages.com/
*/
public static Long hash(String key) {
ByteBuffer buf = ByteBuffer.wrap(key.getBytes());
int seed = 0x1234ABCD;
ByteOrder byteOrder = buf.order();
buf.order(ByteOrder.LITTLE_ENDIAN);
long m = 0xc6a4a7935bd1e995L;
int r = 47;
long h = seed ^ (buf.remaining() * m);
long k;
while (buf.remaining() >= 8) {
k = buf.getLong();
k *= m;
k ^= k >>> r;
k *= m;
h ^= k;
h *= m;
}
if (buf.remaining() > 0) {
ByteBuffer finish = ByteBuffer.allocate(8).order(
ByteOrder.LITTLE_ENDIAN);
// for big-endian version, do this first:
// finish.position(8-buf.remaining());
finish.put(buf).rewind();
h ^= finish.getLong();
h *= m;
}
h ^= h >>> r;
h *= m;
h ^= h >>> r;
buf.order(byteOrder);
return h;
}

加密算法:md5

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
/**
* get hash code on 2^32 ring (md5散列的方式计算hash值)
* @param key
* @return long
*/
public static long hash2(String key) {
// md5 byte
MessageDigest md5;
try {
md5 = MessageDigest.getInstance("MD5");
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException("MD5 not supported", e);
}
md5.reset();
byte[] keyBytes = null;
try {
keyBytes = key.getBytes("UTF-8");
} catch (UnsupportedEncodingException e) {
throw new RuntimeException("Unknown string :" + key, e);
}
md5.update(keyBytes);
byte[] digest = md5.digest();
// hash code, Truncate to 32-bits
long hashCode = ((long) (digest[3] & 0xFF) << 24)
| ((long) (digest[2] & 0xFF) << 16)
| ((long) (digest[1] & 0xFF) << 8)
| (digest[0] & 0xFF);
long truncateHashCode = hashCode & 0xffffffffL;
return truncateHashCode;
}

节点映射

以有序 Map 的形式在内存中缓存每个节点的 Hash 值对应的物理节点信息,所以引入了 TreeMap 进行存储。

为了增加一致性哈希算法中的虚拟节点,在初始化节点映射的过程中,将计算出 实际节点*虚拟节点 的hash值,以 Hash 值为 key,以物理节点标识为 value,以有序 Map 的形式在内存中缓存,作为后续计算数据对象对应的物理节点时的查询数据。代码如下,virtualHash2RealNode 中缓存着所有虚拟节点 Hash 值对应的物理节点信息。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/**
* 虚拟节点数
*/
private final int NODE_NUM = 1000;

/**
* 映射到哈希环上的 虚拟节点+真实节点 (使用 红黑树 排序)
*/
private TreeMap<Long, String> virtualHash2RealNode = new TreeMap<Long, String>();

/**
* 初始化节点(引入虚拟节点)
* init consistency hash ring, put virtual node on the 2^64 ring
*/
public void initVirtual2RealRing(List<String> shards) {
this.shardNodes = shards;
for (String node : shardNodes) {
for (int i = 0; i < NODE_NUM; i++){
long hashCode = hash("SHARD-" + node + "-NODE-" + i);
virtualHash2RealNode.put(hashCode, node);
}
}
}

数据定位节点

已知 virtualHash2RealNode 中存放着物理节点的信息,使用 tailMap() 方法寻找到比该数据大的范围内的所有物理节点,返回第一个节点。

1
2
3
4
5
6
7
8
9
10
11
12
13
/**
* 寻找数据所对应节点
* 从顺时针遇到的第一个节点
* get real node by key's hash on the 2^64
*/
public String getShardInfo(String key) {
long hashCode = hash(key);
SortedMap<Long, String> tailMap = virtualHash2RealNode.tailMap(hashCode);
if (tailMap.isEmpty()) {
return virtualHash2RealNode.get(virtualHash2RealNode.firstKey());
}
return virtualHash2RealNode.get(tailMap.firstKey());
}

工具类

一般在项目中,会把一致性哈希算法包装成工具类使用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
public class ConsistencyHashUtil {

/**
* 实际节点
*/
private List<String> shardNodes;

/**
* 存储节点数
*/
private final int NODE_NUM = 1000;

/**
* 映射到哈希环上的 虚拟节点+真实节点 (使用 红黑树 排序)
*/
private TreeMap<Long, String> virtualHash2RealNode = new TreeMap<Long, String>();
/**
* 初始化节点(引入虚拟节点)
* init consistency hash ring, put virtual node on the 2^64 ring
*/
public void initVirtual2RealRing(List<String> shards) {
this.shardNodes = shards;
for (String node : shardNodes) {
for (int i = 0; i < NODE_NUM; i++){
long hashCode = hash("SHARD-" + node + "-NODE-" + i);
virtualHash2RealNode.put(hashCode, node);
}
}
}
/**
* 寻找数据所对应节点
* 从顺时针遇到的第一个节点
* get real node by key's hash on the 2^64
*/
public String getShardInfo(String key) {
long hashCode = hash(key);
SortedMap<Long, String> tailMap = virtualHash2RealNode.tailMap(hashCode);
if (tailMap.isEmpty()) {
return virtualHash2RealNode.get(virtualHash2RealNode.firstKey());
}
return virtualHash2RealNode.get(tailMap.firstKey());
}
/**
* 打印节点
* prinf ring virtual node info
*/
public void printMap() {
System.out.println(virtualHash2RealNode);
}
/**
* MurMurHash算法,是非加密HASH算法,性能很高,
* 比传统的CRC32,MD5,SHA-1(这两个算法都是加密HASH算法,复杂度本身就很高,带来的性能上的损害也不可避免)
* 等HASH算法要快很多,而且据说这个算法的碰撞率很低.
* http://murmurhash.googlepages.com/
*/
public static Long hash(String key) {
ByteBuffer buf = ByteBuffer.wrap(key.getBytes());
int seed = 0x1234ABCD;
ByteOrder byteOrder = buf.order();
buf.order(ByteOrder.LITTLE_ENDIAN);
long m = 0xc6a4a7935bd1e995L;
int r = 47;
long h = seed ^ (buf.remaining() * m);
long k;
while (buf.remaining() >= 8) {
k = buf.getLong();
k *= m;
k ^= k >>> r;
k *= m;
h ^= k;
h *= m;
}
if (buf.remaining() > 0) {
ByteBuffer finish = ByteBuffer.allocate(8).order(
ByteOrder.LITTLE_ENDIAN);
// for big-endian version, do this first:
// finish.position(8-buf.remaining());
finish.put(buf).rewind();
h ^= finish.getLong();
h *= m;
}
h ^= h >>> r;
h *= m;
h ^= h >>> r;
buf.order(byteOrder);
return h;
}
/**
* get hash code on 2^32 ring (md5散列的方式计算hash值)
* @param key
* @return long
*/
public static long hash2(String key) {
// md5 byte
MessageDigest md5;
try {
md5 = MessageDigest.getInstance("MD5");
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException("MD5 not supported", e);
}
md5.reset();
byte[] keyBytes = null;
try {
keyBytes = key.getBytes("UTF-8");
} catch (UnsupportedEncodingException e) {
throw new RuntimeException("Unknown string :" + key, e);
}
md5.update(keyBytes);
byte[] digest = md5.digest();
// hash code, Truncate to 32-bits
long hashCode = ((long) (digest[3] & 0xFF) << 24)
| ((long) (digest[2] & 0xFF) << 16)
| ((long) (digest[1] & 0xFF) << 8)
| (digest[0] & 0xFF);
long truncateHashCode = hashCode & 0xffffffffL;
return truncateHashCode;
}
public static void main(String[] args) {
List<String> shards = new ArrayList<String>();
shards.add("consumer-uuid-2");
shards.add("consumer-uuid-1");
ConsistencyHashUtil sh = new ConsistencyHashUtil();
sh.initVirtual2RealRing(shards);
sh.printMap();
int consumer1 = 0;
int consumer2 = 0;
for (int i = 0; i < 10000; i++) {
String key = "consumer" + i;
System.out.println(hash(key) + ":" + sh.getShardInfo(key));
if ("consumer-uuid-1".equals(sh.getShardInfo(key))) {
consumer1++;
}
if ("consumer-uuid-2".equals(sh.getShardInfo(key))) {
consumer2++;
}
}
System.out.println("consumer1:" + consumer1);
System.out.println("consumer2:" + consumer2);
/*long start = System.currentTimeMillis();
for (int i = 0; i < 1000 * 1000 * 1000; i++) {
if (i % (100 * 1000 * 1000) == 0) {
System.out.println(i + ":" + hash("key1" + i));
}
}
long end = System.currentTimeMillis();
System.out.println(end - start);*/
}
}

hutool 工具包也有封装好一致性哈希算法的工具类,只需要传入复制的节点个数和节点对象就能初始化节点映射。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
/**
* 一致性Hash算法
* 算法详解:http://blog.csdn.net/sparkliang/article/details/5279393
* 算法实现:https://weblogs.java.net/blog/2007/11/27/consistent-hashing
* @author xiaoleilu
*
* @param <T> 节点类型
*/
public class ConsistentHash<T> implements Serializable{
private static final long serialVersionUID = 1L;

/** Hash计算对象,用于自定义hash算法 */
Hash32<Object> hashFunc;
/** 复制的节点个数 */
private final int numberOfReplicas;
/** 一致性Hash环 */
private final SortedMap<Integer, T> circle = new TreeMap<>();

/**
* 构造,使用Java默认的Hash算法
* @param numberOfReplicas 复制的节点个数,增加每个节点的复制节点有利于负载均衡
* @param nodes 节点对象
*/
public ConsistentHash(int numberOfReplicas, Collection<T> nodes) {
this.numberOfReplicas = numberOfReplicas;
this.hashFunc = key -> {
//默认使用FNV1hash算法
return HashUtil.fnvHash(key.toString());
};
//初始化节点
for (T node : nodes) {
add(node);
}
}

/**
* 构造
* @param hashFunc hash算法对象
* @param numberOfReplicas 复制的节点个数,增加每个节点的复制节点有利于负载均衡
* @param nodes 节点对象
*/
public ConsistentHash(Hash32<Object> hashFunc, int numberOfReplicas, Collection<T> nodes) {
this.numberOfReplicas = numberOfReplicas;
this.hashFunc = hashFunc;
//初始化节点
for (T node : nodes) {
add(node);
}
}

/**
* 增加节点<br>
* 每增加一个节点,就会在闭环上增加给定复制节点数<br>
* 例如复制节点数是2,则每调用此方法一次,增加两个虚拟节点,这两个节点指向同一Node
* 由于hash算法会调用node的toString方法,故按照toString去重
* @param node 节点对象
*/
public void add(T node) {
for (int i = 0; i < numberOfReplicas; i++) {
circle.put(hashFunc.hash32(node.toString() + i), node);
}
}

/**
* 移除节点的同时移除相应的虚拟节点
* @param node 节点对象
*/
public void remove(T node) {
for (int i = 0; i < numberOfReplicas; i++) {
circle.remove(hashFunc.hash32(node.toString() + i));
}
}

/**
* 获得一个最近的顺时针节点
* @param key 为给定键取Hash,取得顺时针方向上最近的一个虚拟节点对应的实际节点
* @return 节点对象
*/
public T get(Object key) {
if (circle.isEmpty()) {
return null;
}
int hash = hashFunc.hash32(key);
if (false == circle.containsKey(hash)) {
SortedMap<Integer, T> tailMap = circle.tailMap(hash); //返回此映射的部分视图,其键大于等于 hash
hash = tailMap.isEmpty() ? circle.firstKey() : tailMap.firstKey();
}
//正好命中
return circle.get(hash);
}
}

传入复制的节点个数和实际物理节点信息,实现一致性哈希。

1
2
3
4
 public static ConsistentHash<Node> makeProxyPool(List<OpenaiProxy> openaiProxies) {
List<Node> realNodes = openaiProxies.stream().map(item -> new Node(item.getHost(), item.getToken())).collect(Collectors.toList());
return new ConsistentHash<>(500, realNodes);
}