深入探究linq原理——如何在自己的语言里实现linq

Posted by pzque on 2018-01-05     

坑挖的有点多,最近打算填一个:给scala加上linq。

在spark RDD和DataFrame上直接用岂不是美滋滋。

用过几次c#,linq还是非常直观的,很喜欢这个设计。不过现在都忘的差不多了,再来回顾一下linq到底是个什么东西。

overview

先不讲类型签名扩展方法这些细节,我们从从官网给的最基本的例子开始,来一个整体的概览,看看linq到底是什么:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
using System;
using System.Linq;

class IntroToLINQ
{
static void Main()
{
// The Three Parts of a LINQ Query:
// 1. Data source.
int[] numbers = new int[7] { 0, 1, 2, 3, 4, 5, 6 };

// 2. Query creation.
// numQuery is an IEnumerable<int>
var numQuery =
from num in numbers
where (num % 2) == 0
select num;

// 3. Query execution.
foreach (int num in numQuery)
{
Console.Write("{0,1} ", num);
}
}
}

linq,语言集成查询,就是语法上支持类sql的查询语法,对于熟悉sql查询的广大coder,可读性比链式方法调用不知高到哪里去了。

但这也只是一层语法糖而已,在编译后还是要转化成方法调用。

比如上面的查询等价于:

1
var numQuery = numbers.Where(num => num %2 ==0).Select(num => num);

当然因为select的数据没变,这个Select调用完全可以省略。而Where就相当于filter,Select就相当于map,这些简单的操作都非常容易理解。linq支持的其他join、aggerate等操作符,同样是写好的一堆方法,比较容易理解。

比较特殊的是:当多个from串联在一起时,事情就变得稍微有些复杂。下面具体介绍一下这种情况。

from和SelectMany

还是先给例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
using System;
using System.Linq;

class IntroToLINQ
{
static void Main ()
{
int[] numbers = new int[7] { 0, 1, 2, 3, 4, 5, 6 };

var numQuery =
from num0 in numbers
from num1 in numbers
where num0 + num1 > 11
select new {num0, num1};

foreach (var num in numQuery) {
Console.Write ("{0,1} ", num);
}
}
}

只看查询语句,它干了什么非常容易理解:在numbers和numbers(自己和自己)的笛卡尔积中选择出两数之和大于11的组合,输出是:

1
2
3
{ num0 = 5, num1 = 6 }
{ num0 = 6, num1 = 5 }
{ num0 = 6, num1 = 6 }

提一些稍微扩展的内容,看不懂没关系,可以跳过,如果有兴趣搞懂的话可以了解一下haskell的Monad:

看着一段scala的代码:

1
2
3
4
5
6
7
8
9
10
11
object Main {
def main(args: Array[String]): Unit = {
val numbers = List(0, 1, 2, 3, 4, 5, 6)

val queryResult =
for (num0 <- numbers; num1 <- numbers
if num0 + num1 > 10) yield (num0, num1)

queryResult.foreach(println)
}
}

输出结果:

1
2
3
(5,6)
(6,5)
(6,6)

然后是一段可以在ghci里执行的haskell代码:

1
2
3
4
5
6
7
8
9
numbers = [0, 1, 2, 3, 4, 5, 6]

queryResult = do {
num0<-numbers;
num1<-numbers;
if num0+num1>10
then [(num0,num1)]
else []
}

haskell没有foreach,手动看一下结果

1
> queryResult

得到

1
=> [(5,6),(6,5),(6,6)]

可以发现得到的结果是一模一样的,事实上这三者本来就是一回事。c#串联的from查询表达式、scala的for语法、haskell的do notation,本质上都是一个东西,都是一层语法糖,把对Monad的操作串联起来,最后都会翻译成方法/函数调用。其底层方法/函数分别是SelectManyflatMap>>=。至于Monad是个什么东西,又是另一个话题了,说大不大说小不小的话题。。。

还是回到正题,我们来看一下数组这些容器的SelectMany方法的具体功能。

接受一个函数的SelectMany

先看一个最普通的SelectMany例子,这也是它原本的语义:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int[] list = { 1, 2, 3 };

Func<int, string[]> selector = x =>
{
var s = String.Format("hi{0}", x);
return new[] { s, s, s };
};

var result = list.SelectMany(selector);

foreach (var e in result)
{
Console.Write("{0},", e);
}

输出是:

1
hi1,hi1,hi1,hi2,hi2,hi2,hi3,hi3,hi3,

很容易理解,SelectMany接受一个函数,这个函数对容器的每个元素应用一遍,每次都返回一个新的容器。比如这里,selector接受数字N然后返回{"hiN","hiN","hiN"}这个列表,对每个元素调用一遍就会得到这个大列表{ {"hi1","hi1","hi1"},{"hi2","hi2","hi2"},{"hi3","hi3","hi3"} },最后把这个大列表拍平就得到最终的结果:{hi1,hi1,hi1,hi2,hi2,hi2,hi3,hi3,hi3}。当然实际的实现不一定这样来,但是这样理解就对了。

让我们看一下这个函数的签名:

1
2
3
4
public static IEnumerable<TResult> SelectMany<TSource, TResult>(
this IEnumerable<TSource> source,
Func<TSource, IEnumerable<TResult>> selector
);

和例子的类型是对应的。

没接触过c#的话需要注意两点:

  1. 第一个参数source前面有个this,这是扩展方法的语法,source.SelectMany(selector)就相当于SelectMany(source,selector)
  2. IEnumerable<T>是一个接口,c#的数组都实现了这个接口。所以int[]满足IEnumerable<int>string[]满足IEnumerable<string>

这个SelectMany有什么用呢?

前面说过了,它能把一系列操作串联起来。我们再来看一个例子,求两个列表的笛卡尔积,比如对于[1,2]he[3,4]我们怎么得到[(1,3),(1,4),(2,3),(2,4)]

1
2
3
4
5
6
7
8
9
10
11
12
13
int[] alist = { 1, 2 };
int[] blist = { 3, 4 };

var result = alist.SelectMany(
a => blist.Select(
b => new { a, b }
)
);

foreach (var e in result)
{
Console.Write("{0},", e);
}

输出是:{ a = 1, b = 3 },{ a = 1, b = 4 },{ a = 2, b = 3 },{ a = 2, b = 4 },

怎么理解呢,看里面函数的功能就可以了,它对于alist的一个元素a,会将其和blist的每个元素组合一次最后生成[(a,3),(a,4)]。对alist里的每个a都来一遍这个函数就得到[ [(a0,3),(a0,4)], [(a1,3),(a1,4)] ],把它拍平就是最后的结果了。

然后以此类推,求三个列表的笛卡尔积:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
int[] alist = { 1, 2 };
int[] blist = { 3, 4 };
int[] clist = { 4, 5 };

var result = alist.SelectMany(
a => blist.SelectMany(
b => clist.Select(
c => new { a, b, c }
)
)
);

foreach (var e in result)
{
Console.Write("{0},", e);
}

输出是:

1
{ a = 1, b = 3, c = 4 },{ a = 1, b = 3, c = 5 },{ a = 1, b = 4, c = 4 },{ a = 1, b = 4, c = 5 },{ a = 2, b = 3, c = 4 },{ a = 2, b = 3, c = 5 },{ a = 2, b = 4, c = 4 },{ a = 2, b = 4, c = 5 },

是不是稍微有一些难理解,其实你只要抓住一点就可以了:它接受的函数参数一定会返回一个列表(说IEnumerable才对,暂时可以理解成列表)。那么这里的函数返回的是个什么列表呢?看前面的例子就可以了:是把 $b_i$ 和 $c_i$ 组合一遍,然后前面加个a。最后所有的列表合在一起就是结果了。

这种嵌套的lambda是非常反人类而且低效的,然而scala就是这么干的(手动滑稽),所以在scala里还是少用for吧。后面会讲c#是怎么避过这个坑。先看另一种SelectMany。

接受两个函数的SelectMany

再回头看一下刚才的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
int[] alist = { 1, 2 };
int[] blist = { 3, 4 };

var result = alist.SelectMany(
a => blist.Select(
b => new { a, b }
)
);

foreach (var e in result)
{
Console.Write("{0},", e);
}

这种SelectMany里套一个Select的模式是非常常见的,所以没有必套2层lambda,c#直接提供了一个接受两个函数的SelectMany:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int[] alist = { 1, 2 };
int[] blist = { 3, 4 };

var result = alist.SelectMany(
_ => blist,
(a, b) => new { a, b }
);

foreach (var e in result)
{
Console.Write("{0},", e);
}
```
和前面的例子是等价的,输出:

{ a = 1, b = 3 },{ a = 1, b = 4 },{ a = 2, b = 3 },{ a = 2, b = 4 },

1
2
3
4
5
6
7
8
9

它的语义更直观,看下它的签名:

```c#
public static IEnumerable<TResult> SelectMany<TSource, TCollection, TResult>(
this IEnumerable<TSource> source,
Func<TSource, IEnumerable<TCollection>> collectionSelector,
Func<TSource, TCollection, TResult> resultSelector
);

意思就是对于source的每个元素a,调用colectionSelector生成一个列表l,然后对al的每个元素b,执行resultSelector(a,b),最后所有的结果组合在一起就是结果了。

再来看一个TSource,TCollection,TResult都不同的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
int[] ilist = { 1, 2 };
double[] dlist = { 0.1, 0.2, 0.3 };

Func<int, double[]> collectionSelector =
_ => dlist;
Func<int, double, string> resultSelector =
(int_num, double_num) => String.Format("'{0}'", int_num + double_num);

var result = ilist.SelectMany(collectionSelector, resultSelector);
foreach (var e in result)
{
Console.Write("{0},", e);
}

结果:

1
'1.1','1.2','1.3','2.1','2.2','2.3',

读者可自行体会。

翻译规则

那么写出它的方法调用形式:

1
2
3
4
numQuery1 = numbers.SelectMany (
_ => numbers,
(num0, num1) => new {num0, num1}
).Where( x => x.num0 + x.num1 >11);