Mar
31
这是罗凯同学布置的 Golang 学习作业。
这题之前用 Python 刷过,用的是二分法,在 [1, n / 2] 区间内,找到第一个 x,使得 x ^ 2 <= n < (x + 1) ^ 2 ,用的是 STL 中 lowerbound 的算法。
罗凯同学提到,应该使用牛顿迭代法来完成。这个方法是听说过的,但是早就忘了,于是到 wikipedia 去找了一下:
https://zh.wikipedia.org/wiki/%E7%89%9B%E9%A1%BF%E6%B3%95
求函数 f(x) 的零点,可以通过选取曲线上的任意一个点 x0 开始,然后计算 x1 = x0 - f(x1) / f'(x1) 的方式迭代,通常得到一个比 x0 更接近零点的 x1 。通过不断迭代,最终我们能找到一个零点 xn 。
对于求平方根,我们是要找到一个 x,使得 x ^2 - n = 0,也就是这里的 f(x) = x ^ 2 - n, f'(x) = 2 * x (勉强还记得这个求导公式……)
有了这个,答案就呼之欲出了:
做完以后,我想起 Quake III 的作者 John Carmack 的 平方根倒数速算法,摘录一段内容:( src: https://blog.csdn.net/zyex1108/article/details/53540824 )
Quake-III Arena (雷神之锤3)是90年代的经典游戏之一。该系列的游戏不但画面和内容不错,而且即使计算机配置低,也能极其流畅地运行。这要归功于它3D引擎的开发者约翰-卡马克(John Carmack)。事实上早在90年代初DOS时代,只要能在PC上搞个小动画都能让人惊叹一番的时候,John Carmack就推出了石破天惊的Castle Wolfstein, 然后再接再励,doom, doomII, Quake...每次都把3-D技术推到极致。他的3D引擎代码资极度高效,几乎是在压榨PC机的每条运算指令。
这个平方根倒数算法正是其中的一个例子。在3D游戏引擎中,求取照明和投影的波动角度与反射效果时,常需计算平方根倒数,而求平方根的常用算法效率较低。
Carmack 通过使用一个惊为天人的魔术常量 0x5f3759df,只需要做 1 次迭代(Quaker III源码中的为了提高精度的第二次迭代被注视掉了),就能得到一个足够精度的平方根,大幅提高了 3D 引擎的运行效率。
关于这个魔术常量,Carmack 表示并不是他自己发明的,至今为止仍未能确切知晓算法中所使用的特殊常数的起源。但 Carmack 凭一己之力,撑起了一个 3D 引擎的时代,以至于在1999年,登上了美国时代杂志评选出来的科技领域50大影响力人物榜单,并且名列第10位。
感兴趣的同学,可以在 Wikipedia 的 平方根倒数速算法 了解更多细节:
https://zh.wikipedia.org/wiki/%E5%B9%B3%E6%96%B9%E6%A0%B9%E5%80%92%E6%95%B0%E9%80%9F%E7%AE%97%E6%B3%95
这题之前用 Python 刷过,用的是二分法,在 [1, n / 2] 区间内,找到第一个 x,使得 x ^ 2 <= n < (x + 1) ^ 2 ,用的是 STL 中 lowerbound 的算法。
class Solution(object):
def mySqrt(self, x):
"""
:type x: int
:rtype: int
"""
if x < 0:
raise Exception("invalid input")
if x < 2:
return x
left = 1
length = x / 2
while length > 1:
half = length / 2
middle = left + half
if middle * middle > x:
length = half
else:
left = middle
length = length - half
return left
def mySqrt(self, x):
"""
:type x: int
:rtype: int
"""
if x < 0:
raise Exception("invalid input")
if x < 2:
return x
left = 1
length = x / 2
while length > 1:
half = length / 2
middle = left + half
if middle * middle > x:
length = half
else:
left = middle
length = length - half
return left
罗凯同学提到,应该使用牛顿迭代法来完成。这个方法是听说过的,但是早就忘了,于是到 wikipedia 去找了一下:
https://zh.wikipedia.org/wiki/%E7%89%9B%E9%A1%BF%E6%B3%95
求函数 f(x) 的零点,可以通过选取曲线上的任意一个点 x0 开始,然后计算 x1 = x0 - f(x1) / f'(x1) 的方式迭代,通常得到一个比 x0 更接近零点的 x1 。通过不断迭代,最终我们能找到一个零点 xn 。
对于求平方根,我们是要找到一个 x,使得 x ^2 - n = 0,也就是这里的 f(x) = x ^ 2 - n, f'(x) = 2 * x (勉强还记得这个求导公式……)
有了这个,答案就呼之欲出了:
import "math"
func mySqrt(x int) int {
f := func (i float64) float64 {
return i * i - float64(x)
}
g := func (i float64) float64 {
return 2 * i
}
var i float64 = 1.0
for math.Abs(f(i)) > 1e-6 {
i = i - f(i) / g(i)
}
return int(math.Floor(i))
}
func mySqrt(x int) int {
f := func (i float64) float64 {
return i * i - float64(x)
}
g := func (i float64) float64 {
return 2 * i
}
var i float64 = 1.0
for math.Abs(f(i)) > 1e-6 {
i = i - f(i) / g(i)
}
return int(math.Floor(i))
}
做完以后,我想起 Quake III 的作者 John Carmack 的 平方根倒数速算法,摘录一段内容:( src: https://blog.csdn.net/zyex1108/article/details/53540824 )
引用
Quake-III Arena (雷神之锤3)是90年代的经典游戏之一。该系列的游戏不但画面和内容不错,而且即使计算机配置低,也能极其流畅地运行。这要归功于它3D引擎的开发者约翰-卡马克(John Carmack)。事实上早在90年代初DOS时代,只要能在PC上搞个小动画都能让人惊叹一番的时候,John Carmack就推出了石破天惊的Castle Wolfstein, 然后再接再励,doom, doomII, Quake...每次都把3-D技术推到极致。他的3D引擎代码资极度高效,几乎是在压榨PC机的每条运算指令。
这个平方根倒数算法正是其中的一个例子。在3D游戏引擎中,求取照明和投影的波动角度与反射效果时,常需计算平方根倒数,而求平方根的常用算法效率较低。
Carmack 通过使用一个惊为天人的魔术常量 0x5f3759df,只需要做 1 次迭代(Quaker III源码中的为了提高精度的第二次迭代被注视掉了),就能得到一个足够精度的平方根,大幅提高了 3D 引擎的运行效率。
关于这个魔术常量,Carmack 表示并不是他自己发明的,至今为止仍未能确切知晓算法中所使用的特殊常数的起源。但 Carmack 凭一己之力,撑起了一个 3D 引擎的时代,以至于在1999年,登上了美国时代杂志评选出来的科技领域50大影响力人物榜单,并且名列第10位。
感兴趣的同学,可以在 Wikipedia 的 平方根倒数速算法 了解更多细节:
https://zh.wikipedia.org/wiki/%E5%B9%B3%E6%96%B9%E6%A0%B9%E5%80%92%E6%95%B0%E9%80%9F%E7%AE%97%E6%B3%95
Jan
29
TLDR版本:https://leetcode-cn.com/explore/ ,注册一个帐号开始做题就行了。
== 以下是正文 ==
作为一个程序员,编码能力是基础的基础。
我比较幸运,在大学的时候参加了学校的 ACM/ICPC 集训队,接触了 ACM/ICPC 比赛。这是一个针对大学生编程能力的世界级比赛,要求在几个小时的时间里完成若干道不同难度的题目,其中很多题目不仅需要复杂的算法、有各种特殊情况需要考虑,而且还有变态级的效率要求。强如楼教主(楼天城),也仅在 2009 年获得世界总决赛的第二名。
此外,从我观测到的结果来看,但凡从集训队走出去的成员(无论其竞赛成绩如何),**其毕业后的第一份工作(通常都是 BAT )乃至之后的发展,都显著高于计算机专业的平均水平**。
虽然在集训队里有教练,也有大神,但日常学习主要还是靠自己。看书学习固然是一种方式,但是比较枯燥,也不容易衡量自己的学习成果。另一方面,由于赛事多年的发展和积淀,国内参赛实力较强的大学(例如北大、杭州电子科技大学、华中科技大学)都创建了自己的在线测评系统(英文名叫 Online Judge,简称 OJ)。
OJ 上沉淀了多年来的竞赛题目,每一个题目都包含相应的题面、输入说明、输出要求、基础测试用例;用户按要求编写代码后,将代码提交给 OJ,系统会在后台启动自动化测试,告知测评结果。
由于 OJ 系统的存在,做题变成了一种乐趣,通过努力解决了一个问题,系统会给出红色的 Accepted 字样,就像一种奖赏;而在这个过程中,也可以直接地看到自己的进步。
工作以后,我非常庆幸当年自己在 OJ 系统刷过这些题,夯实了编程能力,在工作中能够完成更高质量的代码。而在过去几年的面试过程中,我发现很多来应聘的程序员,往往只能应对简单的情况,处理不好边界问题、例外情况、运行效率带来的挑战。
遗憾的是,由于学校自建的 OJ 往往都是学生自己开发、自己维护(我也写过一个,维护过几年,深有体会),体验较差,对存量题目的组织、整理也比较随意(往往只是简单的罗列),而且由于比赛是英文环境,题面往往也都是纯英文的,给竞赛圈之外的同学带来了一定门槛。
所幸,近年来,第三方(商业公司、志愿者社区)的 OJ 系统也逐渐完善,其中一个我很喜欢的平台是 LeetCode ,大约成立于 2008 年吧,上面的题多是业内 TOP 公司的面试题,很多人通过刷这些题来应聘喜欢面试算法的 NTMGBA 系列公司(注:Netease,Tencent,Microsoft,Google,Baidu,Alibaba/Amazon)。
相比各个学校维护的 OJ 平台,LeetCode 的体验令人称道:
* 支持多种语言,包括 PHP、Python、Go、Rust、Javascript,甚至还有基于 MySQL 的题目
* 推出了完整的中文版,包括纯中文的题面
* 对题目做了细致的整理,打上各种标签,包括难度(简单、中等、困难)、话题(字符串、堆/栈、贪心算法、动态规划等)
* 通过合集的方式,将题目整理归档(例如腾讯精选50题、初级算法、中级算法等)
* 对于错误的情况,给出明确的错误原因,及相应的输入输出数据,方便自我纠正
* 许多题目有详尽的官方解答,即使不会做也能够直接学习
LeetCode 上的题目大致可以分成两种(参考 CoolShell 博客说明):
1. 算法题。大多是套路题,每道题目都需要特定的算法,例如BFS、DFS、动态规划、回溯等。通过做这些题,能够让自己对这些最基础的算法的思路有非常扎实的了解和训练,也能很好地锻炼自己的思维能力(烧脑)。
2. 编程题。比如:atoi,strstr,add two num,括号匹配,字符串乘法,通配符匹配,文件路径简化,Text Justification,反转单词等等。这些题目的题面都很简单,大部分程序员都能读懂,但是魔鬼藏在细节中,具体的实现往往需要考虑多种情况。通过做这些题,可以非常好的训练自己对各种情况的考虑,以及对程序代码组织的掌控能力(其实就是其中的状态变量)。程序中的状态正是程序变得复杂难维护的直接原因。
每个程序员内心都有一个大神梦,但是别忘了,大神也是从菜鸟一步一个脚印走过来的,而 LeetCode 就是一个很好的垫脚石,共勉。
== 以下是正文 ==
作为一个程序员,编码能力是基础的基础。
我比较幸运,在大学的时候参加了学校的 ACM/ICPC 集训队,接触了 ACM/ICPC 比赛。这是一个针对大学生编程能力的世界级比赛,要求在几个小时的时间里完成若干道不同难度的题目,其中很多题目不仅需要复杂的算法、有各种特殊情况需要考虑,而且还有变态级的效率要求。强如楼教主(楼天城),也仅在 2009 年获得世界总决赛的第二名。
此外,从我观测到的结果来看,但凡从集训队走出去的成员(无论其竞赛成绩如何),**其毕业后的第一份工作(通常都是 BAT )乃至之后的发展,都显著高于计算机专业的平均水平**。
虽然在集训队里有教练,也有大神,但日常学习主要还是靠自己。看书学习固然是一种方式,但是比较枯燥,也不容易衡量自己的学习成果。另一方面,由于赛事多年的发展和积淀,国内参赛实力较强的大学(例如北大、杭州电子科技大学、华中科技大学)都创建了自己的在线测评系统(英文名叫 Online Judge,简称 OJ)。
OJ 上沉淀了多年来的竞赛题目,每一个题目都包含相应的题面、输入说明、输出要求、基础测试用例;用户按要求编写代码后,将代码提交给 OJ,系统会在后台启动自动化测试,告知测评结果。
由于 OJ 系统的存在,做题变成了一种乐趣,通过努力解决了一个问题,系统会给出红色的 Accepted 字样,就像一种奖赏;而在这个过程中,也可以直接地看到自己的进步。
工作以后,我非常庆幸当年自己在 OJ 系统刷过这些题,夯实了编程能力,在工作中能够完成更高质量的代码。而在过去几年的面试过程中,我发现很多来应聘的程序员,往往只能应对简单的情况,处理不好边界问题、例外情况、运行效率带来的挑战。
遗憾的是,由于学校自建的 OJ 往往都是学生自己开发、自己维护(我也写过一个,维护过几年,深有体会),体验较差,对存量题目的组织、整理也比较随意(往往只是简单的罗列),而且由于比赛是英文环境,题面往往也都是纯英文的,给竞赛圈之外的同学带来了一定门槛。
所幸,近年来,第三方(商业公司、志愿者社区)的 OJ 系统也逐渐完善,其中一个我很喜欢的平台是 LeetCode ,大约成立于 2008 年吧,上面的题多是业内 TOP 公司的面试题,很多人通过刷这些题来应聘喜欢面试算法的 NTMGBA 系列公司(注:Netease,Tencent,Microsoft,Google,Baidu,Alibaba/Amazon)。
相比各个学校维护的 OJ 平台,LeetCode 的体验令人称道:
* 支持多种语言,包括 PHP、Python、Go、Rust、Javascript,甚至还有基于 MySQL 的题目
* 推出了完整的中文版,包括纯中文的题面
* 对题目做了细致的整理,打上各种标签,包括难度(简单、中等、困难)、话题(字符串、堆/栈、贪心算法、动态规划等)
* 通过合集的方式,将题目整理归档(例如腾讯精选50题、初级算法、中级算法等)
* 对于错误的情况,给出明确的错误原因,及相应的输入输出数据,方便自我纠正
* 许多题目有详尽的官方解答,即使不会做也能够直接学习
LeetCode 上的题目大致可以分成两种(参考 CoolShell 博客说明):
1. 算法题。大多是套路题,每道题目都需要特定的算法,例如BFS、DFS、动态规划、回溯等。通过做这些题,能够让自己对这些最基础的算法的思路有非常扎实的了解和训练,也能很好地锻炼自己的思维能力(烧脑)。
2. 编程题。比如:atoi,strstr,add two num,括号匹配,字符串乘法,通配符匹配,文件路径简化,Text Justification,反转单词等等。这些题目的题面都很简单,大部分程序员都能读懂,但是魔鬼藏在细节中,具体的实现往往需要考虑多种情况。通过做这些题,可以非常好的训练自己对各种情况的考虑,以及对程序代码组织的掌控能力(其实就是其中的状态变量)。程序中的状态正是程序变得复杂难维护的直接原因。
每个程序员内心都有一个大神梦,但是别忘了,大神也是从菜鸟一步一个脚印走过来的,而 LeetCode 就是一个很好的垫脚石,共勉。
Sep
19
# 1. 什么是跳表
跳表(Skip List)是基于链表 + 随机化实现的一个有序数据结构,可以达到平均 O(logN) 的查找、插入、删除效率,在实际运行中的效率往往超过 AVL 等平衡二叉树,而且其实现相对更简单、内存消耗更低。
Redis 的 ZSET 底层实现就是用的 Skip List,这里是 [Antirez对此的说明](https://news.ycombinator.com/item?id=1171423)。
这是一个典型的跳表:
解释一下:
1. SkipList 是一个多层的链表
2. 第[0]层的链表包含所有节点,其他层的链表包含部分节点,层次越高,节点越少
3. 每层链表之间会共享相同的节点(节省内存,但为了方便展示,每一层都输出了它的值)
4. 对于某个节点,在插入时通过概率判断它最高会出现在哪一层,并且也会出现在之下的每一层
通过这样的设计,当需要查找某个 key 时,可以从最高层的链表开始往前找,在这一层遇到末尾或者大于 key 的节点时往下走一个层,直到找到 key 节点。
例如:
# 2. 跳表的节点
从上面的描述,我们大概可以知道 (1) 每个节点需要保存一个 key; (2) 每个节点需要有多个next指针 (3) 其 next 指针的数量会在插入时确定
因此我们可以用下面这个 class 来表示节点:
# 3. 创建跳表
一个新创建的跳表是没有节点的。但为了实现的简单起见,可以添加一个头节点:
到目前为止都特别简单,但是还什么也干不了。
# 4. 创建节点
创建节点时,需要先按一定的概率分布确定其高度。
为了保证高层的节点比低层少,我们可以用这样的概率分布:
实现其实非常简单:
这样可以保证平均的路径长度是 log(n) 。
精确一点的话,实际上是 log(n-1, 1/p) / p,也就是说, p 的选择会影响跳表层数、平均路径长度。
具体的计算比较复杂,有兴趣可以参考跳表的原论文《Skip Lists: A Probabilistic Alternative to Balanced Trees》。(TL;DR)
然后我们就可以这样来创建一个新的节点:
node = Node(self.randomHeight(), key)
# 5. 添加节点
如果只是为空跳表添加一个新的节点,只要更新头结点的每一个next指针:
但很显然这个方法只能用一次。
如果跳表中已经有多个节点,那我们就必须找到每一层中适合插入的位置:
这个函数返回一个 update 节点数组,其中的每个节点都是在这一层中小于 key 的最后一个节点。
也就是说,在 level = i 层,总是可以把新的节点插入 update[i] 之后:
但是由于这一版 getUpdateList 是 O(n) 的,插入效率并没有达到跳表的设计目标。
# 6. 添加节点++
考虑这一点:跳表的每一层都是有序的。
也就是说,我们在找到 update[n] = x 以后,其实可以从节点 x 的 n - 1 层继续查找 update[n-1] 应该是哪个节点。
由于查找路径的平均长度是 log(N) ,所以我们可以实现一个更快的 getUpdateList 方法
注意,需要从最高层开始查
# 7. 里程碑1
把上面的代码整合起来,我们就可以得到第一版跳表代码:能够插入节点。
为了更好地展示我们的成果,我们可以用这样一个函数,把链表按第1节的例子样式输出:
试试看:
多尝试几次,以及选择不同的 p 值,可以观察生成跳表的区别。
# 8. 查找节点
实际上查找节点的过程,已经包含在 insert 的实现里了:
# 9. 删除节点
既然已经能找出 update 节点数组,在 level = i 层,只要判断 update[i].next[i] 是否等于要删除的 key 就可以了:
# 10. 里程碑2
整合 find 和 update 数组,就可以实现跳表的基础操作了,试试看:
# 11. 其他
我们在 Node 中只添加了一个 key 属性,在具体的实现中,我们往往可能需要针对 key 存储一个 value,例如 Python 自带的 dict 实现。改造起来也很简单:
1. node 中添加一个 value 属性,并且添加相应的初始化逻辑(__init__方法)
2. 将 SkipList.insert 修改为 `insert(self, key, value)`,在新建 Node 时指定其 value
3. 再添加一个 `update(self, key, value)` API,方便调用方的使用
4. 可以考虑针对语言适配,例如实现 python 的 __getitem__ 、 __setitem__ 等魔术方法
# 12. 完整代码
完。
跳表(Skip List)是基于链表 + 随机化实现的一个有序数据结构,可以达到平均 O(logN) 的查找、插入、删除效率,在实际运行中的效率往往超过 AVL 等平衡二叉树,而且其实现相对更简单、内存消耗更低。
Redis 的 ZSET 底层实现就是用的 Skip List,这里是 [Antirez对此的说明](https://news.ycombinator.com/item?id=1171423)。
这是一个典型的跳表:
[0] -> 0 -> 1 -> 3 -> 4 -> 5 -> 6 -> 7 -> 9 -> nil
[1] -> 0 ------> 3 ------> 5 ------> 7 ------> nil
[2]----------------------> 5-----------------> nil
[1] -> 0 ------> 3 ------> 5 ------> 7 ------> nil
[2]----------------------> 5-----------------> nil
解释一下:
1. SkipList 是一个多层的链表
2. 第[0]层的链表包含所有节点,其他层的链表包含部分节点,层次越高,节点越少
3. 每层链表之间会共享相同的节点(节省内存,但为了方便展示,每一层都输出了它的值)
4. 对于某个节点,在插入时通过概率判断它最高会出现在哪一层,并且也会出现在之下的每一层
通过这样的设计,当需要查找某个 key 时,可以从最高层的链表开始往前找,在这一层遇到末尾或者大于 key 的节点时往下走一个层,直到找到 key 节点。
例如:
引用
4 的查找路径为 [2] -> [1] -> 0 -> 3 -> 3@[0] -> 4
6 的查找路径为 [2] -> 5 -> 5@[1] -> 5@[0] -> 6
8 的查找路径为 [2] -> 5 -> 5@[1] -> 7 -> 7@[0] -> 9 (找不到)
6 的查找路径为 [2] -> 5 -> 5@[1] -> 5@[0] -> 6
8 的查找路径为 [2] -> 5 -> 5@[1] -> 7 -> 7@[0] -> 9 (找不到)
# 2. 跳表的节点
从上面的描述,我们大概可以知道 (1) 每个节点需要保存一个 key; (2) 每个节点需要有多个next指针 (3) 其 next 指针的数量会在插入时确定
因此我们可以用下面这个 class 来表示节点:
class Node(object)
def __init__(self, height, key):
self.key = key
self.next = [None] * height
def height(self):
return len(self.next)
def __init__(self, height, key):
self.key = key
self.next = [None] * height
def height(self):
return len(self.next)
# 3. 创建跳表
一个新创建的跳表是没有节点的。但为了实现的简单起见,可以添加一个头节点:
class SkipList(object):
def __init__(self):
self.head = Node(0, None) #头节点高度为0,不需要key
def __init__(self):
self.head = Node(0, None) #头节点高度为0,不需要key
到目前为止都特别简单,但是还什么也干不了。
# 4. 创建节点
创建节点时,需要先按一定的概率分布确定其高度。
为了保证高层的节点比低层少,我们可以用这样的概率分布:
引用
Height(n) = p^n
实现其实非常简单:
import random
def randomHeight(self, p = 0.5):
height = 1
while random.uniform(0, 1) < p and self.head.height() >= height:
height += 1
return height
def randomHeight(self, p = 0.5):
height = 1
while random.uniform(0, 1) < p and self.head.height() >= height:
height += 1
return height
这样可以保证平均的路径长度是 log(n) 。
精确一点的话,实际上是 log(n-1, 1/p) / p,也就是说, p 的选择会影响跳表层数、平均路径长度。
具体的计算比较复杂,有兴趣可以参考跳表的原论文《Skip Lists: A Probabilistic Alternative to Balanced Trees》。(TL;DR)
然后我们就可以这样来创建一个新的节点:
node = Node(self.randomHeight(), key)
# 5. 添加节点
如果只是为空跳表添加一个新的节点,只要更新头结点的每一个next指针:
def insertFirstNode(self, key):
node = Node(self.randomHeight(), key)
while node.height > self.head.height():
self.head.next.append(None) #保证头节点的next数组覆盖所有层次的链表
for level in range(node.height()):
node.next[level] = self.head.next[level]
self.head.next[level] = node
node = Node(self.randomHeight(), key)
while node.height > self.head.height():
self.head.next.append(None) #保证头节点的next数组覆盖所有层次的链表
for level in range(node.height()):
node.next[level] = self.head.next[level]
self.head.next[level] = node
但很显然这个方法只能用一次。
如果跳表中已经有多个节点,那我们就必须找到每一层中适合插入的位置:
def getUpdateList(self, key):
update = [None] * self.head.height()
for level in range(len(update)):
x = self.head
while x.next[level] is not None and x.next[level].key < key:
x = x.next[level]
update[level] = x
return update
update = [None] * self.head.height()
for level in range(len(update)):
x = self.head
while x.next[level] is not None and x.next[level].key < key:
x = x.next[level]
update[level] = x
return update
这个函数返回一个 update 节点数组,其中的每个节点都是在这一层中小于 key 的最后一个节点。
也就是说,在 level = i 层,总是可以把新的节点插入 update[i] 之后:
def insert(self, key):
node = Node(self.randomHeight(), key)
while node.height > self.head.height():
self.head.next.append(None) #保证头节点的next数组覆盖所有层次的链表
update = self.getUpdateList(key)
next0 = update[0].next[0]
if next0 is not None and next0.key == key:
return # 0层总是包含所有元素;如果 update[0] 的下一个节点与key相等,则无需插入。
for level in range(node.height()):
node.next[level] = update[level].next[level]
update[level].next[level] = node
node = Node(self.randomHeight(), key)
while node.height > self.head.height():
self.head.next.append(None) #保证头节点的next数组覆盖所有层次的链表
update = self.getUpdateList(key)
next0 = update[0].next[0]
if next0 is not None and next0.key == key:
return # 0层总是包含所有元素;如果 update[0] 的下一个节点与key相等,则无需插入。
for level in range(node.height()):
node.next[level] = update[level].next[level]
update[level].next[level] = node
但是由于这一版 getUpdateList 是 O(n) 的,插入效率并没有达到跳表的设计目标。
# 6. 添加节点++
考虑这一点:跳表的每一层都是有序的。
也就是说,我们在找到 update[n] = x 以后,其实可以从节点 x 的 n - 1 层继续查找 update[n-1] 应该是哪个节点。
由于查找路径的平均长度是 log(N) ,所以我们可以实现一个更快的 getUpdateList 方法
注意,需要从最高层开始查
def getUpdateList(self, key):
update = [None] * self.head.height()
x = self.head
for level in reversed(range(len(update))):
while x.next[level] is not None and x.next[level].key < key:
x = x.next[level]
update[level] = x
return update
update = [None] * self.head.height()
x = self.head
for level in reversed(range(len(update))):
while x.next[level] is not None and x.next[level].key < key:
x = x.next[level]
update[level] = x
return update
# 7. 里程碑1
把上面的代码整合起来,我们就可以得到第一版跳表代码:能够插入节点。
为了更好地展示我们的成果,我们可以用这样一个函数,把链表按第1节的例子样式输出:
def dump(self):
for i in range(self.head.height()):
sys.stdout.write('[H]')
x = self.head.next[0]
y = self.head.next[i]
while x is not None:
s = ' -> %s' % x.key
if x is y:
y = y.next[i]
else:
s = '-' * len(s)
x = x.next[0]
sys.stdout.write(s)
print ' -> <nil>'
print
for i in range(self.head.height()):
sys.stdout.write('[H]')
x = self.head.next[0]
y = self.head.next[i]
while x is not None:
s = ' -> %s' % x.key
if x is y:
y = y.next[i]
else:
s = '-' * len(s)
x = x.next[0]
sys.stdout.write(s)
print ' -> <nil>'
试试看:
sl = SkipList()
for i in range(10):
sl.insert(sl)
s1.dump()
for i in range(10):
sl.insert(sl)
s1.dump()
[H] -> 0 -> 1 -> 2 -> 3 -> 4 -> 5 -> 6 -> 7 -> 8 -> 9 -> <nil>
[H]----- -> 1 -> 2 -> 3---------- -> 6 -> 7---------- -> <nil>
[H]---------- -> 2-------------------- -> 7---------- -> <nil>
[H]----- -> 1 -> 2 -> 3---------- -> 6 -> 7---------- -> <nil>
[H]---------- -> 2-------------------- -> 7---------- -> <nil>
多尝试几次,以及选择不同的 p 值,可以观察生成跳表的区别。
# 8. 查找节点
实际上查找节点的过程,已经包含在 insert 的实现里了:
def find(self, key):
update = self.getUpdateList(key)
if len(update) == 0:
return None
next0 = update[0].next[0]
if next0 is not None and next0.key == key:
return next0 # 0层总是包含所有元素;如果 update[0] 的下一个节点与key相等,则无需插入。
else:
return None
update = self.getUpdateList(key)
if len(update) == 0:
return None
next0 = update[0].next[0]
if next0 is not None and next0.key == key:
return next0 # 0层总是包含所有元素;如果 update[0] 的下一个节点与key相等,则无需插入。
else:
return None
# 9. 删除节点
既然已经能找出 update 节点数组,在 level = i 层,只要判断 update[i].next[i] 是否等于要删除的 key 就可以了:
def remove(self, key):
update = self.getUpdateList(key)
for i, node in enumerate(update):
if node.next[i] is not None and node.next[i].key == key:
node.next[i] = node.next[i].next[i]
update = self.getUpdateList(key)
for i, node in enumerate(update):
if node.next[i] is not None and node.next[i].key == key:
node.next[i] = node.next[i].next[i]
# 10. 里程碑2
整合 find 和 update 数组,就可以实现跳表的基础操作了,试试看:
node = sl.find(3)
print node
for i in range(7, 14):
sl.remove(i)
sl.dump()
print node
for i in range(7, 14):
sl.remove(i)
sl.dump()
# 11. 其他
我们在 Node 中只添加了一个 key 属性,在具体的实现中,我们往往可能需要针对 key 存储一个 value,例如 Python 自带的 dict 实现。改造起来也很简单:
1. node 中添加一个 value 属性,并且添加相应的初始化逻辑(__init__方法)
2. 将 SkipList.insert 修改为 `insert(self, key, value)`,在新建 Node 时指定其 value
3. 再添加一个 `update(self, key, value)` API,方便调用方的使用
4. 可以考虑针对语言适配,例如实现 python 的 __getitem__ 、 __setitem__ 等魔术方法
# 12. 完整代码
#coding:utf-8
import random
class Node(object):
def __init__(self, height, key=None):
self.key = key
self.next = [None] * height
def height(self):
return len(self.next)
class SkipList(object):
def __init__(self):
self.head = Node(0, None) #头节点高度为0,不需要key
def randomHeight(self, p = 0.5):
height = 1
while random.uniform(0, 1) < p and self.head.height() >= height:
height += 1
return height
def insert(self, key):
node = Node(self.randomHeight(), key)
print node.height(), node.key
while node.height() > self.head.height():
self.head.next.append(None) #保证头节点的next数组覆盖所有层次的链表
update = self.getUpdateList(key)
if update[0].next[0] is not None and update[0].next[0].key == key:
return # 0层总是包含所有元素;如果 update[0] 的下一个节点与key相等,则无需插入。
for level in range(node.height()):
node.next[level] = update[level].next[level]
update[level].next[level] = node
def getUpdateList(self, key):
update = [None] * self.head.height()
x = self.head
for level in reversed(range(len(update))):
while x.next[level] is not None and x.next[level].key < key:
x = x.next[level]
update[level] = x
return update
def dump(self):
for i in range(self.head.height()):
sys.stdout.write('[H]')
x = self.head.next[0]
y = self.head.next[i]
while x is not None:
s = ' -> %s' % x.key
if x is y:
y = y.next[i]
else:
s = '-' * len(s)
x = x.next[0]
sys.stdout.write(s)
print ' -> <nil>'
print
def find(self, key):
update = self.getUpdateList(key)
if len(update) == 0:
return None
next0 = update[0].next[0]
if next0 is not None and next0.key == key:
return next0 # 0层总是包含所有元素;如果 update[0] 的下一个节点与key相等,则无需插入。
else:
return None
def remove(self, key):
update = self.getUpdateList(key)
for i, node in enumerate(update):
if node.next[i] is not None and node.next[i].key == key:
node.next[i] = node.next[i].next[i]
import random
class Node(object):
def __init__(self, height, key=None):
self.key = key
self.next = [None] * height
def height(self):
return len(self.next)
class SkipList(object):
def __init__(self):
self.head = Node(0, None) #头节点高度为0,不需要key
def randomHeight(self, p = 0.5):
height = 1
while random.uniform(0, 1) < p and self.head.height() >= height:
height += 1
return height
def insert(self, key):
node = Node(self.randomHeight(), key)
print node.height(), node.key
while node.height() > self.head.height():
self.head.next.append(None) #保证头节点的next数组覆盖所有层次的链表
update = self.getUpdateList(key)
if update[0].next[0] is not None and update[0].next[0].key == key:
return # 0层总是包含所有元素;如果 update[0] 的下一个节点与key相等,则无需插入。
for level in range(node.height()):
node.next[level] = update[level].next[level]
update[level].next[level] = node
def getUpdateList(self, key):
update = [None] * self.head.height()
x = self.head
for level in reversed(range(len(update))):
while x.next[level] is not None and x.next[level].key < key:
x = x.next[level]
update[level] = x
return update
def dump(self):
for i in range(self.head.height()):
sys.stdout.write('[H]')
x = self.head.next[0]
y = self.head.next[i]
while x is not None:
s = ' -> %s' % x.key
if x is y:
y = y.next[i]
else:
s = '-' * len(s)
x = x.next[0]
sys.stdout.write(s)
print ' -> <nil>'
def find(self, key):
update = self.getUpdateList(key)
if len(update) == 0:
return None
next0 = update[0].next[0]
if next0 is not None and next0.key == key:
return next0 # 0层总是包含所有元素;如果 update[0] 的下一个节点与key相等,则无需插入。
else:
return None
def remove(self, key):
update = self.getUpdateList(key)
for i, node in enumerate(update):
if node.next[i] is not None and node.next[i].key == key:
node.next[i] = node.next[i].next[i]
完。
Mar
2
我注意到过去几个月我司有些同学还在踩一个简单的分布式事务Case的坑,而这个坑在两年以前就已经有同学踩过了,这里简单解析一下这个case和合适的处理方案,供参考。
1. 踩过的坑
这个case有很多变种,先从我们在XX业务踩过的坑开始,大约是16年9月,核心业务需求是很简单的:在用户发起支付请求的时候,从用户的银行卡扣一笔钱。负责这个需求的同学是这么写的代码(去除其他业务逻辑的简化版):
乍一看好像是没有什么毛病,测试的case都顺利提供过,也没有人去仔细review这一小段代码,于是就这么上线了。但问题很快就暴露出来,PaySvr在支付成功以后尝试回调,XX业务系统报错”订单不存在”。查询线上日志发现,这笔订单在请求第三方支付通道时网络超时,Curl抛了timeout异常,导致支付记录被回滚。有心的同学可以自己复现一下这个问题,观察BUG的发生过程。
代码修复起来倒是很简单,在请求PaySvr之前提交事务,将支付请求安全落库即可。
把这个实现代入多个不同的业务下,还会衍生出更多问题,比如被动代扣业务,就可能因为重试导致用户被多次扣款,引起投诉(支付通道对投诉率的要求非常严格,甚至可能导致通道被关停);更严重的是放款业务,可能出现重复放款,给公司造成直接损失。据说某友商就是因为重复放款倒闭的,所以在实现类似业务时特别注意,考虑周全。
2. 归纳总结
我们往后退一步再审视这个case,这段简单的代码涉及了两个系统:XX业务系统(本地数据库)、PaySvr(外部系统)。可以看得出这段代码的本意,是期望将当前系统的业务和外部系统的业务,合并到一个事务里面,要么一起成功提交,要么一起失败回滚,从而保持两个系统的一致性。
之所以没能达到预期,直接原因是,在失败(异常)回滚的时候,只回滚了本地事务,而没有回滚远端系统的状态变化。按这个思路去考虑,似乎只要加一个 PaySvr::rollbackRequest($order->id) 好像就可以解决问题。但仔细想想就会发现远没这么简单,从业务上来说,如果已经给用户付款了,那实际上就是要给用户退款,而往往这时候是掉单(支付请求结果未知),我们又无法直接给用户退款;更极端一点,如果这个rollback请求也超时了呢,那本地可以rollback吗?
这就是分布式事务棘手的地方了,只靠这样的逻辑是无法保证跨系统的一致性的。解决这个问题的方法是引入两段式提交(2 Phase Commit,常常简写为2PC),其基本逻辑是,在两个系统分别完成业务,然后再分别提交。
例如我们上面的解决方案,实际上就是2PC的一个实现:我们把业务需求作为一整个事务,它可以拆成两个子事务(第三方支付通道完成代扣,在XX业务系统记录支付请求成功并修改相应业务状态),每个子事务又分成两个阶段:第一阶段,是在本地先记录支付请求(状态为待确认),并向第三方支付发出代扣请求(结果不确定);第二阶段,在确认第三方代扣成功以后,修改本地支付请求的状态修改为成功,或者是代扣结果为失败,本地支付请求状态记为失败。两个阶段都完成,这个事务才是真的完成了。
3. Case变种
仔细思考我们曾经实现过的需求,可能会在很多看似不起眼的地方发现分布式事务,例如我们在的存管匹配系统里面,就有这样一个Case。
由于与XX银行存管系统交互的延迟比较大,所以我们的匹配系统实现是异步的,匹配系统在撮合了资金和资产以后,会生成一条债权关系记录在本地,随后再发送到XX银行执行资金的划拨。为了提高执行的效率,我们希望在债权关系生成以后,尽快执行资金的划拨,因此我们会把资金划拨的指令通过LPush放进Redis的list里;List的另一端,那些使用BLPOP监听数据的worker会立刻被激活去执行。
如果没有仔细思考,代码可能会这么写:
在实际执行这段代码的时候,如果没有仔细测试(尤其是在有补单逻辑,捞出未执行成功的划拨指令再发送给银行),可能就不会发现,实际上有很多指令并不是马上被执行的,因为relation_id被送进list以后,worker马上就会读出来执行,但这时事务可能还没有提交。但这只是影响了业务的效率,还没有对业务的正确性产生影响。
为了修复这个问题,似乎可以这么做:把 [capital_id, project_id, amount] 发送到redis,worker直接取出执行,这样就不用从数据库读取relation,保证尽快将请求发送到银行。但如果因为某些原因,事务最终没有被提交呢?找银行rollback这些指令的执行,那就麻烦多了。
正确的做法是,在事务提交了以后,再lPush到Redis里:
注:foreach要放到try-catch后面。
最后想说,我相信有很多同学知道这个Case,或者就算不知道也不会犯这样的错误,因此也许会觉得没必要专门揪出来这样分享 —— 但“知识的诅咒”就是这样,“我会的东西都是简单的”,然而对于没有踩过坑的同学来说,其实都是宝贵的经验;另一方面,有些别人觉得简单的问题、踩过的坑,也许自己是不知道的。所以希望大家都能分享自己在工作学习中踩过的坑、解决过的问题,互相交流,互相提高。
1. 踩过的坑
这个case有很多变种,先从我们在XX业务踩过的坑开始,大约是16年9月,核心业务需求是很简单的:在用户发起支付请求的时候,从用户的银行卡扣一笔钱。负责这个需求的同学是这么写的代码(去除其他业务逻辑的简化版):
$dbTrans = $db->beginTransaction();
try {
$order = PayRequest::model()->newPayRequest(...); #在数据库中插入一条支付请求记录,状态为待支付
//其他业务改动
$result = PaySvr::pay($order->id, $order->amount); #请求PaySvr(或第三方支付通道)扣款
if ($result['code'] == PaySvr::E_SUCCESS) {
$order->setAsSucceeded();
} else {
$order->setAsPending();
}
$dbTrans->commit();
} catch (Exception $e) {
$dbTrans->rollback();
}
try {
$order = PayRequest::model()->newPayRequest(...); #在数据库中插入一条支付请求记录,状态为待支付
//其他业务改动
$result = PaySvr::pay($order->id, $order->amount); #请求PaySvr(或第三方支付通道)扣款
if ($result['code'] == PaySvr::E_SUCCESS) {
$order->setAsSucceeded();
} else {
$order->setAsPending();
}
$dbTrans->commit();
} catch (Exception $e) {
$dbTrans->rollback();
}
乍一看好像是没有什么毛病,测试的case都顺利提供过,也没有人去仔细review这一小段代码,于是就这么上线了。但问题很快就暴露出来,PaySvr在支付成功以后尝试回调,XX业务系统报错”订单不存在”。查询线上日志发现,这笔订单在请求第三方支付通道时网络超时,Curl抛了timeout异常,导致支付记录被回滚。有心的同学可以自己复现一下这个问题,观察BUG的发生过程。
代码修复起来倒是很简单,在请求PaySvr之前提交事务,将支付请求安全落库即可。
$dbTrans = $db->beginTransaction();
try {
$order = PayRequest::model()->newPayRequest(...);
//其他业务改动
$dbTrans->commit(); #先将支付请求落地
} catch (Exception $e) {
$dbTrans->rollback();
}
#再请求PaySvr
$result = PaySvr::pay($order->id, $order->amount);
#根据PaySvr结果修改支付请求和其他业务记录的状态
$dbTrans = $db->beginTransaction();
try {
if ($result['code'] == PaySvr::E_SUCCESS) {
$order->setAsSucceeded();
//其他业务改动
} else {
$order->setAsPending();
//其他业务改动
}
} catch (Exception $e) {
$dbTrans->rollback();
}
try {
$order = PayRequest::model()->newPayRequest(...);
//其他业务改动
$dbTrans->commit(); #先将支付请求落地
} catch (Exception $e) {
$dbTrans->rollback();
}
#再请求PaySvr
$result = PaySvr::pay($order->id, $order->amount);
#根据PaySvr结果修改支付请求和其他业务记录的状态
$dbTrans = $db->beginTransaction();
try {
if ($result['code'] == PaySvr::E_SUCCESS) {
$order->setAsSucceeded();
//其他业务改动
} else {
$order->setAsPending();
//其他业务改动
}
} catch (Exception $e) {
$dbTrans->rollback();
}
把这个实现代入多个不同的业务下,还会衍生出更多问题,比如被动代扣业务,就可能因为重试导致用户被多次扣款,引起投诉(支付通道对投诉率的要求非常严格,甚至可能导致通道被关停);更严重的是放款业务,可能出现重复放款,给公司造成直接损失。据说某友商就是因为重复放款倒闭的,所以在实现类似业务时特别注意,考虑周全。
2. 归纳总结
我们往后退一步再审视这个case,这段简单的代码涉及了两个系统:XX业务系统(本地数据库)、PaySvr(外部系统)。可以看得出这段代码的本意,是期望将当前系统的业务和外部系统的业务,合并到一个事务里面,要么一起成功提交,要么一起失败回滚,从而保持两个系统的一致性。
之所以没能达到预期,直接原因是,在失败(异常)回滚的时候,只回滚了本地事务,而没有回滚远端系统的状态变化。按这个思路去考虑,似乎只要加一个 PaySvr::rollbackRequest($order->id) 好像就可以解决问题。但仔细想想就会发现远没这么简单,从业务上来说,如果已经给用户付款了,那实际上就是要给用户退款,而往往这时候是掉单(支付请求结果未知),我们又无法直接给用户退款;更极端一点,如果这个rollback请求也超时了呢,那本地可以rollback吗?
这就是分布式事务棘手的地方了,只靠这样的逻辑是无法保证跨系统的一致性的。解决这个问题的方法是引入两段式提交(2 Phase Commit,常常简写为2PC),其基本逻辑是,在两个系统分别完成业务,然后再分别提交。
例如我们上面的解决方案,实际上就是2PC的一个实现:我们把业务需求作为一整个事务,它可以拆成两个子事务(第三方支付通道完成代扣,在XX业务系统记录支付请求成功并修改相应业务状态),每个子事务又分成两个阶段:第一阶段,是在本地先记录支付请求(状态为待确认),并向第三方支付发出代扣请求(结果不确定);第二阶段,在确认第三方代扣成功以后,修改本地支付请求的状态修改为成功,或者是代扣结果为失败,本地支付请求状态记为失败。两个阶段都完成,这个事务才是真的完成了。
3. Case变种
仔细思考我们曾经实现过的需求,可能会在很多看似不起眼的地方发现分布式事务,例如我们在的存管匹配系统里面,就有这样一个Case。
由于与XX银行存管系统交互的延迟比较大,所以我们的匹配系统实现是异步的,匹配系统在撮合了资金和资产以后,会生成一条债权关系记录在本地,随后再发送到XX银行执行资金的划拨。为了提高执行的效率,我们希望在债权关系生成以后,尽快执行资金的划拨,因此我们会把资金划拨的指令通过LPush放进Redis的list里;List的另一端,那些使用BLPOP监听数据的worker会立刻被激活去执行。
如果没有仔细思考,代码可能会这么写:
#匹配系统
function matcher() {
$dbTrans = $db->beginTransaction();
try {
foreach (matchCapitalAndProject() as $match_result) {
list($capital_id, $project_id, $amount) = $match_result;
$relation = Relation::model()->create($capital_id, $project_id, $amount);
$redis->lPush($relation->id);
}
$dbTrans->commit();
} catch (Exception $e) {
$dbTrans->rollback();
}
}
#Worker
function Worker() {
while (true) {
$id = $redis->brPop();
$relation = Relation::model()->findByPk($id);
if ($relation) {
BankApi::invest($relation->capital_id, $relation->project_id, $amount);
}
}
}
function matcher() {
$dbTrans = $db->beginTransaction();
try {
foreach (matchCapitalAndProject() as $match_result) {
list($capital_id, $project_id, $amount) = $match_result;
$relation = Relation::model()->create($capital_id, $project_id, $amount);
$redis->lPush($relation->id);
}
$dbTrans->commit();
} catch (Exception $e) {
$dbTrans->rollback();
}
}
#Worker
function Worker() {
while (true) {
$id = $redis->brPop();
$relation = Relation::model()->findByPk($id);
if ($relation) {
BankApi::invest($relation->capital_id, $relation->project_id, $amount);
}
}
}
在实际执行这段代码的时候,如果没有仔细测试(尤其是在有补单逻辑,捞出未执行成功的划拨指令再发送给银行),可能就不会发现,实际上有很多指令并不是马上被执行的,因为relation_id被送进list以后,worker马上就会读出来执行,但这时事务可能还没有提交。但这只是影响了业务的效率,还没有对业务的正确性产生影响。
为了修复这个问题,似乎可以这么做:把 [capital_id, project_id, amount] 发送到redis,worker直接取出执行,这样就不用从数据库读取relation,保证尽快将请求发送到银行。但如果因为某些原因,事务最终没有被提交呢?找银行rollback这些指令的执行,那就麻烦多了。
正确的做法是,在事务提交了以后,再lPush到Redis里:
#匹配系统
function matcher() {
$arr_relation = [];
$dbTrans = $db->beginTransaction();
try {
foreach (matchCapitalAndProject() as $match_result) {
list($capital_id, $project_id, $amount) = $match_result;
$relation = Relation::model()->create($capital_id, $project_id, $amount);
$arr_relation[] = $relation;
}
$dbTrans->commit();
} catch (Exception $e) {
$dbTrans->rollback();
}
foreach ($arr_relation as $relation) {
$redis->lPush($relation->id);
}
}
function matcher() {
$arr_relation = [];
$dbTrans = $db->beginTransaction();
try {
foreach (matchCapitalAndProject() as $match_result) {
list($capital_id, $project_id, $amount) = $match_result;
$relation = Relation::model()->create($capital_id, $project_id, $amount);
$arr_relation[] = $relation;
}
$dbTrans->commit();
} catch (Exception $e) {
$dbTrans->rollback();
}
foreach ($arr_relation as $relation) {
$redis->lPush($relation->id);
}
}
注:foreach要放到try-catch后面。
最后想说,我相信有很多同学知道这个Case,或者就算不知道也不会犯这样的错误,因此也许会觉得没必要专门揪出来这样分享 —— 但“知识的诅咒”就是这样,“我会的东西都是简单的”,然而对于没有踩过坑的同学来说,其实都是宝贵的经验;另一方面,有些别人觉得简单的问题、踩过的坑,也许自己是不知道的。所以希望大家都能分享自己在工作学习中踩过的坑、解决过的问题,互相交流,互相提高。
Jul
17
前几天讨论遇到一个涉及区间覆盖的数据统计,发现很适合使用线段树来解决,于是重新回顾了一下这个好几年前学过的东西,凭着残存的理解,好了好久才勉强写了出来,感觉自己确实是没有搞算法的天赋,在边界处理的时候磕磕碰碰的,需要改好几次才能写对,不够干净利落。
不过能从繁杂的业务中抽出来写写纯粹的数据结构和算法,有点回到学校的状态,感觉也蛮不错。
以前在学校折腾算法的时候,从yyt同学的分享的ppt学到了这个数据结构,印象比较深的是,ppt上说,对于一个长度是 x 的线段,使用数组(元素 i 的左右节点分别是 2*i 和 2*i+1 )来记录的话,需要大小约为 3*x 的数组,但是在实际做题的时候却发现越界了,后来仔细去挖这个地方,才发现,其实应该是找到一个 y = 2^n 满足 2^(n-1) < x <= 2^n,所需的数组长度为 2y 。
这次是先写了一个C++的class(偷懒用的struct),配合一些c-style的函数,让python用ctypes载入使用。然后兴起写了个Python的版本对比,跑了个简单的case,发现性能居然相差200+倍,做了一些改进,才领悟到,对于python来说,其实用数组建树比起直接用对象指针关联建树并没有太大优势,而用对象建树的好处是,如果不是极端情况,可以lazy load左右子树(当然,数组也可以lazy initialization子节点,但不能减少内存占用)。改起来也不难,于是就验证了一下,效果相当好,甚至比C++还快(因为测试case太简单,几乎没有展开子树),此外这种方式节点的数量可以减少到2 * x - 1(但是相应地每个节点需要增加指针)。
另外遇到一个问题是迭代,Python的迭代器如果使用Generator语法(yield),写起来和用起来都特别自然,可是到C++就完全不一样了,形式上想要达到类似的效果比较累,保存和还原现场比较辛苦(不过这个case还好),试着实现了一个版本,但是需要遍历所有的节点感觉不太好,后来还是改成了偷懒的写法(直接生成整个结果集,在结果集上迭代)。
最后,不成熟的小代码放在了这里:https://github.com/felix021/mycodes/tree/master/segtree
[update] 到数据集上实际跑了一下,C++版还是跑赢了几倍的速度,这还是没有做lazy init的情况,回头抽空再写个版本验证一下吧~
不过能从繁杂的业务中抽出来写写纯粹的数据结构和算法,有点回到学校的状态,感觉也蛮不错。
以前在学校折腾算法的时候,从yyt同学的分享的ppt学到了这个数据结构,印象比较深的是,ppt上说,对于一个长度是 x 的线段,使用数组(元素 i 的左右节点分别是 2*i 和 2*i+1 )来记录的话,需要大小约为 3*x 的数组,但是在实际做题的时候却发现越界了,后来仔细去挖这个地方,才发现,其实应该是找到一个 y = 2^n 满足 2^(n-1) < x <= 2^n,所需的数组长度为 2y 。
这次是先写了一个C++的class(偷懒用的struct),配合一些c-style的函数,让python用ctypes载入使用。然后兴起写了个Python的版本对比,跑了个简单的case,发现性能居然相差200+倍,做了一些改进,才领悟到,对于python来说,其实用数组建树比起直接用对象指针关联建树并没有太大优势,而用对象建树的好处是,如果不是极端情况,可以lazy load左右子树(当然,数组也可以lazy initialization子节点,但不能减少内存占用)。改起来也不难,于是就验证了一下,效果相当好,甚至比C++还快(因为测试case太简单,几乎没有展开子树),此外这种方式节点的数量可以减少到2 * x - 1(但是相应地每个节点需要增加指针)。
另外遇到一个问题是迭代,Python的迭代器如果使用Generator语法(yield),写起来和用起来都特别自然,可是到C++就完全不一样了,形式上想要达到类似的效果比较累,保存和还原现场比较辛苦(不过这个case还好),试着实现了一个版本,但是需要遍历所有的节点感觉不太好,后来还是改成了偷懒的写法(直接生成整个结果集,在结果集上迭代)。
最后,不成熟的小代码放在了这里:https://github.com/felix021/mycodes/tree/master/segtree
[update] 到数据集上实际跑了一下,C++版还是跑赢了几倍的速度,这还是没有做lazy init的情况,回头抽空再写个版本验证一下吧~
Oct
29
转置二维数组:
utf-8字符串转为utf-8字符数组:
按显示宽度截取utf-8字符串
让进程在后台运行(detached process),出乎意料地简单
function transpose($array) {
array_unshift($array, null);
return call_user_func_array('array_map', $array);
}
array_unshift($array, null);
return call_user_func_array('array_map', $array);
}
utf-8字符串转为utf-8字符数组:
function utf8_str2arr($str)
{
preg_match_all("/./u", $str, $arr);
return $arr[0];
}
{
preg_match_all("/./u", $str, $arr);
return $arr[0];
}
按显示宽度截取utf-8字符串
function substr_width($str, $start, $width)
{
$arr = utf8_str2arr($str);
$arr_ret = [];
$i = 0;
while ($width > 0 and $i < count($arr))
{
$arr_ret[] = $arr[$start + $i];
if (strlen($arr_ret[$i]) == 1) //ascii,width=1
$width -= 1;
else
$width -= 2;
$i++;
}
if ($width < 0)
array_pop($arr_ret);
return join('', $arr_ret);
}
{
$arr = utf8_str2arr($str);
$arr_ret = [];
$i = 0;
while ($width > 0 and $i < count($arr))
{
$arr_ret[] = $arr[$start + $i];
if (strlen($arr_ret[$i]) == 1) //ascii,width=1
$width -= 1;
else
$width -= 2;
$i++;
}
if ($width < 0)
array_pop($arr_ret);
return join('', $arr_ret);
}
让进程在后台运行(detached process),出乎意料地简单
pclose(popen("nohup $cmd &", 'r'));
Jul
30
简单地说,MPI 就是个并行计算框架,模型也很直接——就是多进程。和hadoop不同,它不提供计算任务的map和reduce,只提供了一套通信接口,需要程序员来完成这些任务;它也不提供冗余容错等机制,完全依赖于其下层的可靠性。但是因为把控制权几乎完全交给了程序员,所以有很大的灵活性,可以最大限度地榨取硬件性能。超级计算机上的运算任务,基本上都是使用MPI来开发的。
~ 下载编译安装:
现在貌似一般都用MPICH,开源的MPI库,可以从这里获取: http://www.mpich.org/ ,现在的最新版本是3.0.4,编译安装过程可以参考安装包里的README的说明,基本步骤如下(万恶的configure):
$ wget http://www.mpich.org/static/downloads/3.0.4/mpich-3.0.4.tar.gz
$ tar zxf mpich-3.0.4.tar.gz
$ cd mpich-3.0.4
$ mkdir ~/mpich
$ ./configure --prefix=$HOME/mpich --disable-f77 --disable-fc 2>&1 | tee c.txt #我禁用了fortran的支持
$ make -j4 2>&1 | tee m.txt
$ make install 2>&1 | tee i.txt
$ echo 'export PATH=$PATH:~/mpich/bin' >> ~/.bashrc
下面给出三个例子,参考教程:http://wenku.baidu.com/view/ee8bf3390912a216147929f3.html (注:22页有BUG,它把 MPI_Comm_XXX 错写成了 MPI_Common_xxx //包括全大写版本,共四处),给出了MPI框架中最常用、最基础的6个API的例子。更复杂的API可以参考mpich的手册。这些例子只是简单地演示了MPI框架的使用;实际上在使用MPI开发并行计算的软件时,还需要考虑到很多方面的问题,这里就不展开说了(其实真相是我也不会-.-,有兴趣的话可以请教 @momodi 和 @dumbear 两位)。
1. 最简单的:Hello world
代码如下: hello.c
编译:
$ mpicc -o hello hello.c
运行:
$ mpiexec -n 4 ./hello
Hello world!
Hello world!
Hello world!
Hello world!
可以看到这里启动了4个进程。注意 -n 和 4 之间一定要有空格,否则会报错。
2. 进程间通信
MPI最基本的通信接口是 MPI_Send/MPI_Recv:
编译运行:
$ mpicc comm.c
$ mpiexec -n 4 ./a.out
I'm 0 of 4
from 1: hello from 1
from 2: hello from 2
I'm 1 of 4
I'm 2 of 4
from 3: hello from 3
I'm 3 of 4
3. 来个复杂点的:数数前1亿个自然数里有几个 雷劈数
代码后附,答案是97(真少),不过这不是重点,重点是MPI对硬件的利用率是怎样 :D
测试机器是 16核 AMD Opteron 6128HE @2GHz,32G内存
单进程(无MPI版本):56.9s
4进程:15.3s
8进程:7.85s
12进程:5.35s
考虑到16核跑满可能会受到其他进程的影响(性能不稳定,4.2~4.9s),这个数据就不列进来比较了。
可以看出来,在这个例子里,因为通信、同步只有在计算完之后才有那么一点点,所以在SMP架构下,耗费的时间基本上是跟进程数成反比的,说明MPI框架对硬件性能的利用率还是相当高的。
具体代码如下:
~ 下载编译安装:
现在貌似一般都用MPICH,开源的MPI库,可以从这里获取: http://www.mpich.org/ ,现在的最新版本是3.0.4,编译安装过程可以参考安装包里的README的说明,基本步骤如下(万恶的configure):
引用
$ wget http://www.mpich.org/static/downloads/3.0.4/mpich-3.0.4.tar.gz
$ tar zxf mpich-3.0.4.tar.gz
$ cd mpich-3.0.4
$ mkdir ~/mpich
$ ./configure --prefix=$HOME/mpich --disable-f77 --disable-fc 2>&1 | tee c.txt #我禁用了fortran的支持
$ make -j4 2>&1 | tee m.txt
$ make install 2>&1 | tee i.txt
$ echo 'export PATH=$PATH:~/mpich/bin' >> ~/.bashrc
下面给出三个例子,参考教程:http://wenku.baidu.com/view/ee8bf3390912a216147929f3.html (注:22页有BUG,它把 MPI_Comm_XXX 错写成了 MPI_Common_xxx //包括全大写版本,共四处),给出了MPI框架中最常用、最基础的6个API的例子。更复杂的API可以参考mpich的手册。这些例子只是简单地演示了MPI框架的使用;实际上在使用MPI开发并行计算的软件时,还需要考虑到很多方面的问题,这里就不展开说了(其实真相是我也不会-.-,有兴趣的话可以请教 @momodi 和 @dumbear 两位)。
1. 最简单的:Hello world
代码如下: hello.c
#include <stdio.h>
#include <mpi.h>
int main(int argc, char *argv[])
{
MPI_Init(&argc, &argv); //初始化MPI环境
printf("Hello world!\n");
MPI_Finalize(); //结束MPI环境
return 0;
}
#include <mpi.h>
int main(int argc, char *argv[])
{
MPI_Init(&argc, &argv); //初始化MPI环境
printf("Hello world!\n");
MPI_Finalize(); //结束MPI环境
return 0;
}
编译:
$ mpicc -o hello hello.c
运行:
$ mpiexec -n 4 ./hello
Hello world!
Hello world!
Hello world!
Hello world!
可以看到这里启动了4个进程。注意 -n 和 4 之间一定要有空格,否则会报错。
2. 进程间通信
MPI最基本的通信接口是 MPI_Send/MPI_Recv:
#include <stdio.h>
#include <mpi.h>
int main(int argc, char *argv[])
{
int myid, numprocs, source, msg_tag = 0;
char msg[100];
MPI_Status status;
MPI_Init(&argc, &argv);
MPI_Comm_size(MPI_COMM_WORLD, &numprocs); //共启动几个进程
MPI_Comm_rank(MPI_COMM_WORLD, &myid); //当前进程的编号(0~n-1)
printf("I'm %d of %d\n", myid, numprocs);
if (myid != 0)
{
int len = sprintf(msg, "hello from %d", myid);
MPI_Send(msg, len, MPI_CHAR, 0, msg_tag, MPI_COMM_WORLD); //向id=0的进程发送信息
}
else
{
for (source = 1; source < numprocs; source++)
{
//从id=source的进程接受消息
MPI_Recv(msg, 100, MPI_CHAR, source, msg_tag, MPI_COMM_WORLD, &status);
printf("from %d: %s\n", source, msg);
}
}
MPI_Finalize();
return 0;
}
#include <mpi.h>
int main(int argc, char *argv[])
{
int myid, numprocs, source, msg_tag = 0;
char msg[100];
MPI_Status status;
MPI_Init(&argc, &argv);
MPI_Comm_size(MPI_COMM_WORLD, &numprocs); //共启动几个进程
MPI_Comm_rank(MPI_COMM_WORLD, &myid); //当前进程的编号(0~n-1)
printf("I'm %d of %d\n", myid, numprocs);
if (myid != 0)
{
int len = sprintf(msg, "hello from %d", myid);
MPI_Send(msg, len, MPI_CHAR, 0, msg_tag, MPI_COMM_WORLD); //向id=0的进程发送信息
}
else
{
for (source = 1; source < numprocs; source++)
{
//从id=source的进程接受消息
MPI_Recv(msg, 100, MPI_CHAR, source, msg_tag, MPI_COMM_WORLD, &status);
printf("from %d: %s\n", source, msg);
}
}
MPI_Finalize();
return 0;
}
编译运行:
$ mpicc comm.c
$ mpiexec -n 4 ./a.out
I'm 0 of 4
from 1: hello from 1
from 2: hello from 2
I'm 1 of 4
I'm 2 of 4
from 3: hello from 3
I'm 3 of 4
3. 来个复杂点的:数数前1亿个自然数里有几个 雷劈数
代码后附,答案是97(真少),不过这不是重点,重点是MPI对硬件的利用率是怎样 :D
测试机器是 16核 AMD Opteron 6128HE @2GHz,32G内存
单进程(无MPI版本):56.9s
4进程:15.3s
8进程:7.85s
12进程:5.35s
考虑到16核跑满可能会受到其他进程的影响(性能不稳定,4.2~4.9s),这个数据就不列进来比较了。
可以看出来,在这个例子里,因为通信、同步只有在计算完之后才有那么一点点,所以在SMP架构下,耗费的时间基本上是跟进程数成反比的,说明MPI框架对硬件性能的利用率还是相当高的。
具体代码如下:
#include <stdio.h>
#include <mpi.h>
int is_lp(long long x)
{
long long t = x * x, i = 10;
while (i < t)
{
long long l = t / i, r = t % i;
if (l + r == x)
return 1;
i *= 10;
}
return 0;
}
int main(int argc, char *argv[])
{
int myid, numprocs, source;
const int N = 100000000;
MPI_Status status;
MPI_Init(&argc, &argv);
MPI_Comm_rank(MPI_COMM_WORLD, &myid);
MPI_Comm_size(MPI_COMM_WORLD, &numprocs);
printf("I'm %d of %d\n", myid, numprocs);
int start = myid * (N / numprocs), stop = (myid + 1) * (N / numprocs);
if (myid == numprocs - 1)
stop = N;
printf("start from %d to %d\n", start, stop);
int ans = 0, i;
for (i = start; i < stop; i++)
if (is_lp(i))
ans += 1;
printf("%d finished calculation with %d numbers\n", myid, ans);
if (myid != 0)
{
MPI_Send(&ans, 1, MPI_INT, 0, 0, MPI_COMM_WORLD);
}
else
{
int tmp;
for (source = 1; source < numprocs; source++)
{
MPI_Recv(&tmp, 1, MPI_INT, source, 0, MPI_COMM_WORLD, &status);
printf("from %d: %d\n", source, tmp);
ans += tmp;
}
printf("final ans: %d\n", ans);
}
MPI_Finalize();
return 0;
}
#include <mpi.h>
int is_lp(long long x)
{
long long t = x * x, i = 10;
while (i < t)
{
long long l = t / i, r = t % i;
if (l + r == x)
return 1;
i *= 10;
}
return 0;
}
int main(int argc, char *argv[])
{
int myid, numprocs, source;
const int N = 100000000;
MPI_Status status;
MPI_Init(&argc, &argv);
MPI_Comm_rank(MPI_COMM_WORLD, &myid);
MPI_Comm_size(MPI_COMM_WORLD, &numprocs);
printf("I'm %d of %d\n", myid, numprocs);
int start = myid * (N / numprocs), stop = (myid + 1) * (N / numprocs);
if (myid == numprocs - 1)
stop = N;
printf("start from %d to %d\n", start, stop);
int ans = 0, i;
for (i = start; i < stop; i++)
if (is_lp(i))
ans += 1;
printf("%d finished calculation with %d numbers\n", myid, ans);
if (myid != 0)
{
MPI_Send(&ans, 1, MPI_INT, 0, 0, MPI_COMM_WORLD);
}
else
{
int tmp;
for (source = 1; source < numprocs; source++)
{
MPI_Recv(&tmp, 1, MPI_INT, source, 0, MPI_COMM_WORLD, &status);
printf("from %d: %d\n", source, tmp);
ans += tmp;
}
printf("final ans: %d\n", ans);
}
MPI_Finalize();
return 0;
}
Jul
29
nth_element是一个用烂了的面试题,09年的时候我也曾经跑过一点数据(回头一看好像有点不太对,那时候CPU有那么慢吗?应该是当时没有把读取数据的时间给去掉吧),昨天看到 从一道笔试题谈算法优化(上) 和 从一道笔试题谈算法优化(下) 这个系列,里面提到了从简单到复杂的6种算法,根据作者的说法,经过各种优化以后solution6达到了相当的效率。受到作者启发,我也想到了一种算法,于是花了些时间,把常用的几种算法实现了进行对比(包括对比stl里的nth_element)。
回到 nth_element 的问题:从 n 个数里头,找出其中的 top k (top可以是找最大 也可以是找最小)。简单起见,后面都以找最小为例。
那篇文章里的后面的3种算法和我的算法都是基于它的solution3进行优化的,所以先介绍一下solution3:
1. 将 n 个数的前 k 个拷贝出来
2. 找出这 k 个数的最大值 m
3. 对于每一个剩下的 n - k 个数 i :如果 i 小于 m ,将 m 替换为 i ,然后再找出新的 k 个数中的最大值 m
4. 返回过滤后得到的 k 个数
虽然第3步的判断条件成立时,这一步是 O(k) 的复杂度,但是在大部分情况下,这个判断条件都不成立,所以它需要执行的次数很少(这种效果越往后越明显,因为 m 的值越来越小,可以滤掉更多的数)。但是对于极端情况(或者接近极端的情况)——如果数组完全是降序排列的话,那么这个条件每次都会成立,会导致算法退化成 O(n * m) ,性能就不可接受了。
solution4~6的优化我觉得读起来有点晦涩,有兴趣的同学可以自己去看,这里不展开说了。它们都存在算法退化的情况。
我的算法可能理解起来要简单些,基本思路是偷懒:把缓冲区设置为 2 * k ,在扫数组的时候,如果这个数比当前保留的最大值还要小,就把它塞进缓冲区,直到缓冲区塞满了,再排序、取最大值、删多余的数。
1. 开辟一个 2 × k 的临时数组 r
2. 将 N[1..k] 拷贝到 r[1..k],并记录 r 的末尾 e = k + 1
3. 找出 r[1..k] 中的最大值 m
4. 对于每一个剩下的 n - k 个数 i:
4.1 如果 i 小于 m,将 m 塞到 r 的末尾 r[e];e = e + 1
4.1.1 如果 e == 2×k(r被塞满了),对 r 进行排序,得出前k个数中的最大值;抛弃后面的k个数,e = k + 1
5. 对 r[1..e] 进行排序
6. 返回 r[1..k]
这个算法的实际效果很好,在 n = 1亿、k = 1万 的情况下,对于随机的n个数,4.1条件成立的次数通常只有十几次,几乎可以忽略,因此大约只需要扫描整个数组时间的2~3倍即可。
但是这个算法同样存在退化的问题:如果全都是倒序排列的话,也变成O(n*m)了。幸而解决方案也很简单:对这个数组进行 n / 100 次的随机化(考虑到随机化耗时也比较大,100这个数字是试了几次以后发现的合适值,不一定最优,但是都差不多了):
p.s. 很不幸的是随机化这个方法对于solution3来说虽然提升也很明显,但是远远不够。solution4~6也不太行。
这里给出随机情况下和完全倒序情况下的性能对比吧(ubuntu 12.04 x86_64,i5 2400@3.1G):
补充说明一下,nth_element算法在面试的时候可能给出的 n 会更大,以至于在内存中存不下,这时候通常会认为用堆来实现最好——因为只要O(k)的空间就可以搞定,而且最差时间复杂度是O(logk * n),但是要注意logk的常数实际上是很大的,所以你可以看到前面的数据里 堆算法 的效率并不是最好的。事实上这里给出的所有算法(不考虑随机化的话)也都只需要O(k)的空间;如果需要随机化的话也很简单,分段处理,只要保证每段能在内存中保存下来就行了。
最后上代码存档(为了方便测试用了些全局变量,看起来可能有点挫):
回到 nth_element 的问题:从 n 个数里头,找出其中的 top k (top可以是找最大 也可以是找最小)。简单起见,后面都以找最小为例。
那篇文章里的后面的3种算法和我的算法都是基于它的solution3进行优化的,所以先介绍一下solution3:
1. 将 n 个数的前 k 个拷贝出来
2. 找出这 k 个数的最大值 m
3. 对于每一个剩下的 n - k 个数 i :如果 i 小于 m ,将 m 替换为 i ,然后再找出新的 k 个数中的最大值 m
4. 返回过滤后得到的 k 个数
虽然第3步的判断条件成立时,这一步是 O(k) 的复杂度,但是在大部分情况下,这个判断条件都不成立,所以它需要执行的次数很少(这种效果越往后越明显,因为 m 的值越来越小,可以滤掉更多的数)。但是对于极端情况(或者接近极端的情况)——如果数组完全是降序排列的话,那么这个条件每次都会成立,会导致算法退化成 O(n * m) ,性能就不可接受了。
solution4~6的优化我觉得读起来有点晦涩,有兴趣的同学可以自己去看,这里不展开说了。它们都存在算法退化的情况。
我的算法可能理解起来要简单些,基本思路是偷懒:把缓冲区设置为 2 * k ,在扫数组的时候,如果这个数比当前保留的最大值还要小,就把它塞进缓冲区,直到缓冲区塞满了,再排序、取最大值、删多余的数。
1. 开辟一个 2 × k 的临时数组 r
2. 将 N[1..k] 拷贝到 r[1..k],并记录 r 的末尾 e = k + 1
3. 找出 r[1..k] 中的最大值 m
4. 对于每一个剩下的 n - k 个数 i:
4.1 如果 i 小于 m,将 m 塞到 r 的末尾 r[e];e = e + 1
4.1.1 如果 e == 2×k(r被塞满了),对 r 进行排序,得出前k个数中的最大值;抛弃后面的k个数,e = k + 1
5. 对 r[1..e] 进行排序
6. 返回 r[1..k]
这个算法的实际效果很好,在 n = 1亿、k = 1万 的情况下,对于随机的n个数,4.1条件成立的次数通常只有十几次,几乎可以忽略,因此大约只需要扫描整个数组时间的2~3倍即可。
但是这个算法同样存在退化的问题:如果全都是倒序排列的话,也变成O(n*m)了。幸而解决方案也很简单:对这个数组进行 n / 100 次的随机化(考虑到随机化耗时也比较大,100这个数字是试了几次以后发现的合适值,不一定最优,但是都差不多了):
void rand_factor()
{
//randomization
const int factor = 100;
int r, t, i;
for (i = n / factor; i > 0; i--)
{
r = rand() % n;
t = a[0], a[0] = a[r], a[r] = t;
}
}
{
//randomization
const int factor = 100;
int r, t, i;
for (i = n / factor; i > 0; i--)
{
r = rand() % n;
t = a[0], a[0] = a[r], a[r] = t;
}
}
p.s. 很不幸的是随机化这个方法对于solution3来说虽然提升也很明显,但是远远不够。solution4~6也不太行。
这里给出随机情况下和完全倒序情况下的性能对比吧(ubuntu 12.04 x86_64,i5 2400@3.1G):
引用
随机数据
========
随机化生成数据: 7.087s
遍历耗时: 0.053s
1/100随机化耗时: 0.064s
STL的nth_element: 0.841s
最大堆+随机化: 0.162s
Solution3+随机化: 3.121s
solution4+随机化: 2.825s
solution6+随机化: 0.249s
我的算法+随机化: 0.163s
倒序数据
========
无随机化生成数据: 0.150s
遍历耗时: 0.052s
1/100随机化耗时: 0.067s
STL的nth_element: 1.420s
最大堆+随机化: 0.301s
Solution3+随机化: 67.552s
solution4+随机化: 66.405s
solution6+随机化: 2.388s
我的算法+随机化: 0.170s
========
随机化生成数据: 7.087s
遍历耗时: 0.053s
1/100随机化耗时: 0.064s
STL的nth_element: 0.841s
最大堆+随机化: 0.162s
Solution3+随机化: 3.121s
solution4+随机化: 2.825s
solution6+随机化: 0.249s
我的算法+随机化: 0.163s
倒序数据
========
无随机化生成数据: 0.150s
遍历耗时: 0.052s
1/100随机化耗时: 0.067s
STL的nth_element: 1.420s
最大堆+随机化: 0.301s
Solution3+随机化: 67.552s
solution4+随机化: 66.405s
solution6+随机化: 2.388s
我的算法+随机化: 0.170s
补充说明一下,nth_element算法在面试的时候可能给出的 n 会更大,以至于在内存中存不下,这时候通常会认为用堆来实现最好——因为只要O(k)的空间就可以搞定,而且最差时间复杂度是O(logk * n),但是要注意logk的常数实际上是很大的,所以你可以看到前面的数据里 堆算法 的效率并不是最好的。事实上这里给出的所有算法(不考虑随机化的话)也都只需要O(k)的空间;如果需要随机化的话也很简单,分段处理,只要保证每段能在内存中保存下来就行了。
最后上代码存档(为了方便测试用了些全局变量,看起来可能有点挫):