从lambda表达式开始说起           

笔者在开发中遇到这样一个需求:从上游收到一个集合,要求按照集合元素的某个属性进行分组,最终需要得到一个key为分组内元素的属性值,value为组内元素id的集合。笔者尝试用一次stream表达式完成均未成功,最终通过查询Api才顺利完成。这让笔者意识到对stream的了解比较肤浅,于是抽出时间专门研究学习一下。具体例子如下所示:

//**before**
List<Hosts> hosts = hostsRepository.find(hostQuery);
Map<Boolean, List<Hosts>> msEnableHostMap = hosts.stream()
    .collect(Collectors.partitioningBy(
        host -> host.getMsEnabled() == Hosts.MS_ENABLE));
microSegClientService.updateSwitch(
    msEnableHostMap.get(Boolean.TRUE).stream().map(Hosts::getAgentId).collect(toList()), 
    Hosts.MS_ENABLE);
microSegClientService.updateSwitch(
    msEnableHostMap.get(Boolean.FALSE).stream().map(Hosts::getAgentId).collect(toList()), 
    Hosts.MS_DISABLE);
        
//**after**
List<Hosts> hosts = hostsRepository.find(hostQuery);
Map<Boolean, List<String>> msEnableHostMap = hosts.stream()
    .collect(partitioningBy(host -> host.getMsEnabled() == Hosts.MS_ENABLE,
                Collectors.mapping(Hosts::getAgentId, toList())));
microSegClientService.updateSwitch(msEnableHostMap.get(Boolean.TRUE), Hosts.MS_ENABLE);
microSegClientService.updateSwitch(msEnableHostMap.get(Boolean.FALSE), Hosts.MS_DISABLE);

重写之后代码由12行变为6行。这里有一个小插曲,在笔者准备整理成文档时,不知是巧合还是精准推荐,“阿里开发者”公众号推送:万字长文详解Java lambda表达式。笔者阅读总体感觉太”干”了,和实践的联系稍微少一些,对于略懂stream但在使用中遇到问题的读者,达不到quick start的目的。因此笔者在借鉴之余,还是继续完成此篇文章。写这篇文章的目的有二:

  1. 由简单到稍微复杂地举几个使用例子,方便需要使用的读者复制稍微修改一下即可使用,起到quick start目的。笔者尽量提供所能想到的,后续随时补充
  2. 结合实例讲解源码,这部分希望达到两个目的:
    1. 读者可以熟练地使用JDK提供的Api
    2. 笔者努力领会代码的精髓和奥妙,希望对编程水平有所提升

Stream 快速使用

在本示例中,笔者首先自定义一个集合元素类Man,如下所示:

class Man {
    Integer id;
    String name;
    Integer age;
}

获取姓名的集合

/**
  * [201, 202, 203, 204, 202, 205]
  */
List<String> names = men.stream().map(Man::getName).collect(Collectors.toList());

根据年龄来分组

/**
  * {21=[Man{id=1, name='201', age=21}], 22=[Man{id=2, name='202', age=22}, 
  *  Man{id=6, name='205', age=22}], 23=[Man{id=3, name='203', age=23}, 
  *  Man{id=5, name='202', age=23}], 24=[Man{id=4, name='204', age=24}]}
 */
Map<Integer, List<Man>> ageManMap = men.stream().collect(Collectors.groupingBy(Man::getAge));

根据年龄分组,只返回每个人的id

/**
  * {21=[1], 22=[2, 6], 23=[3, 5], 24=[4]}
  */
Map<Integer, List<Integer>> ageManIdMap = men.stream().collect(Collectors.groupingBy(Man::getAge,
    Collectors.mapping(Man::getId, Collectors.toList())));

根据年龄分组,只返回id最大的Man

/**
 * {21=Optional[Man{id=1, name='201', age=21}], 22=Optional[Man{id=6, name='205', age=22}],
 * 23=Optional[Man{id=5, name='202', age=23}], 24=Optional[Man{id=4, name='204', age=24}]}
 */
Map<Integer, Optional<Man>> ageIdMaxMan = men.stream().collect(Collectors.groupingBy(Man::getAge, 
    Collectors.maxBy((left, right) -> left.age - right.age)));

根据年龄分组,每组取最后一个

/**
 * {21=Man{id=1, name='201', age=21},22=Man{id=6, name='205', age=22},
 * 23=Man{id=5, name='202', age=23},24=Man{id=4, name='204', age=24}}
 */
Map<Integer, Man> ageLastMan = men.stream().collect(Collectors.groupingBy(Man::getAge,
     Collectors.collectingAndThen(Collectors.toList(), t -> t.get(t.size() - 1))));

代码实现

stream的所有操作分为两类:中间操作和终止操作。中间操作是指处理完毕生成一个新的stream,传给下游继续处理。终止操作是指生成最终的结果,stream 到此就结束了。中间操作主要有(图抄自网上):

终止操作主要有(图抄自网上):

中间操作比较简单。相对而言,终止操作稍微复杂一些。终止操作有很多,但可以分为四类:find、foreach、match和reduce,其余操作底层都是以其中之一实现的,平常使用最多的collect方法就是以reduce来实现的。

stream分为两类,一种是串行(stream),一种是并行(parallelStream)。中间操作也分为两类,无状态和有状态,像filter、map就是无状态,limit、distinct就是有状态。笔者先讲解串行实现,然后讲解并行实现。

Java实现stream的总体思想是链式,将整个操作链路抽象成链表,其中每个操作抽象成链表中的节点。当遇到终止操作(对应终止节点)时,从头部遍历链表各个节点,得到最终结果。精简后的类图如下所示:

中间操作抽象成了ReferencePipeline类,有三个实现:HEAD/StatelessOp/StatefulOp,对应头部/无状态/有状态操作。其中HEAD表示是链表的起点,由${Collection}.stream()操作触发创建。终止操作抽象成了TerminalOp类, 有四个实现:FindOp/ForEachOp/ReduceOp/MatchOp, 对应上述的四类终止操作。终止操作和中间操作都依赖Sink接口,Sink定义了链表每个节点对数据的操作。综上所述,ReferencePipeline定义了中间节点、TerminalOp定义了终止节点、Sink定义了每个节点的数据处理。笔者首先分析下这三个类的定义:

ReferencePipeline

ReferencePipeline采用了模板实现方式,和链表操作相关的逻辑在AbstractPipeline的构造方法中。代码如下所示:

abstract class AbstractPipeline<E_IN, E_OUT, S extends BaseStream<E_OUT, S>>
        extends PipelineHelper<E_OUT> implements BaseStream<E_OUT, S> {

    private final AbstractPipeline sourceStage;
    @SuppressWarnings("rawtypes")
    private final AbstractPipeline previousStage;
    protected final int sourceOrOpFlags;
    @SuppressWarnings("rawtypes")
    private AbstractPipeline nextStage;
    private int depth;   
    private boolean linkedOrConsumed;
    private boolean parallel;

    AbstractPipeline(Supplier<? extends Spliterator<?>> source,
                     int sourceFlags, boolean parallel) {
        this.previousStage = null;
        this.sourceSupplier = source;
        this.sourceStage = this;
        this.sourceOrOpFlags = sourceFlags & StreamOpFlag.STREAM_MASK;
        this.combinedFlags = (~(sourceOrOpFlags << 1)) & StreamOpFlag.INITIAL_OPS_VALUE;
        this.depth = 0;
        this.parallel = parallel;
    }

    AbstractPipeline(AbstractPipeline<?, E_IN, ?> previousStage, int opFlags) {
        if (previousStage.linkedOrConsumed)
            throw new IllegalStateException(MSG_STREAM_LINKED);
        previousStage.linkedOrConsumed = true;
        previousStage.nextStage = this;

        this.previousStage = previousStage;
        this.sourceOrOpFlags = opFlags & StreamOpFlag.OP_MASK;
        this.combinedFlags = StreamOpFlag.combineOpFlags(opFlags, previousStage.combinedFlags);
        this.sourceStage = previousStage.sourceStage;
        if (opIsStateful())
            sourceStage.sourceAnyStateful = true;
        this.depth = previousStage.depth + 1;
    }
}

创建HEAD节点使用第一个构造器,其他节点使用第二个构造器。可以看到构造的是一个双向链表,每个AbstractPipeline都有指向前一个节点的previousStage和后一个节点的nextStage,还有一个表征节点在链表中位置的depth。sourceStage表示起始节点,便于找到HEAD节点而设置的。

Sink

Sink作为定义数据处理的接口,有四个方法,如下所示:

interface Sink<T> extends Consumer<T> {
   default void begin(long size) {}
   default void end() {}
   default boolean cancellationRequested() {return false;}
   default void accept(int value) {
      throw new IllegalStateException("called wrong accept method");
   }
}

当集合元素开始遍历处理之前,先调用begin, 然后对每个集合元素都调用accept, 遍历完毕调用end方法。cancellationRequested 方法表示节点不希望接受元素了(笔者也没搞明白什么场景会调用)。

TerminalOp

TerminalOp定义了终止操作,在接口定义中两个方法evaluateParallel和evaluateSequential,分别在并行和串行时调用。每个TerminalOp的实现类,都会实现这两个方法,代码如下所示:

interface TerminalOp<E_IN, R> {
    default StreamShape inputShape() { return StreamShape.REFERENCE; }

    default int getOpFlags() { return 0; }

    default <P_IN> R evaluateParallel(PipelineHelper<E_IN> helper,
                                      Spliterator<P_IN> spliterator) {
        if (Tripwire.ENABLED)
            Tripwire.trip(getClass(), "{0} triggering TerminalOp.evaluateParallel serial default");
        return evaluateSequential(helper, spliterator);
    }

    <P_IN> R evaluateSequential(PipelineHelper<E_IN> helper,
                                Spliterator<P_IN> spliterator);
}

结合实例

笔者以具体的例子来讲解代码,仍然以上述Man类为例,获取所有Man的name的集合,有三种实现方式,如下所示:

names = men.stream().map(Man::getName).reduce(
            new ArrayList<String>(),
            (left, right) -> {left.add(right);return left;},
            (left, right) -> { left.addAll(right); return left; });

names = men.stream().map(Man::getName).collect(
            ArrayList::new,
            List::add,
            (left, right) -> left.addAll(right));

names = men.stream().map(Man::getName).collect(Collectors.toList());

可以看到终止操作用了三种方式,一种reduce和两种collect。笔者将对三种实现方式都进行讲解,以探究其中的区别。

创建Stream

public static <T> Stream<T> stream(Spliterator<T> spliterator, boolean parallel) {
        Objects.requireNonNull(spliterator);
        return new ReferencePipeline.Head<>(spliterator,
                                            StreamOpFlag.fromCharacteristics(spliterator),
                                            parallel);
}

stream方法创建了一个ReferencePipeline.Head的头结点。其中spliterator是对集合Collection的封装,parallel是标志位,stream()方法为false,parallelStream()为true。

map

public final <R> Stream<R> map(Function<? super P_OUT, ? extends R> mapper) {
        Objects.requireNonNull(mapper);
        return new StatelessOp<P_OUT, R>(this, StreamShape.REFERENCE,
                                     StreamOpFlag.NOT_SORTED | StreamOpFlag.NOT_DISTINCT) {
            @Override
            Sink<P_OUT> opWrapSink(int flags, Sink<R> sink) {
                return new Sink.ChainedReference<P_OUT, R>(sink) {
                    @Override
                    public void accept(P_OUT u) {
                        downstream.accept(mapper.apply(u));
                    }
                };
            }
        };
    }

map方法创建了一个ReferencePipeline.StatelessOp对象。和创建HEAD不同的是,StatelessOp对象实现了opWrapSink方法。opWrapSink中创建了一个Sink的实现类Sink.ChainedReference,实现了accept方法。实现逻辑也比较简单,将迭代元素经传入的Lambda处理之后,给下游的accept方法继续处理。到这个地方,读者可能会意识到,这种链式设计思路,使每迭代一个元素,就会经历所有的节点处理,而不是等待上一个节点处理完毕所有元素,然后下个节点继续处理。

reduce

reduce的主要操作是在ReduceOps的makeRef方法中,具体逻辑如下:

public static <T, U> TerminalOp<T, U>
    makeRef(U seed, BiFunction<U, ? super T, U> reducer, BinaryOperator<U> combiner) {
        class ReducingSink extends Box<U> implements AccumulatingSink<T, U, ReducingSink> {
            @Override
            public void begin(long size) {
                state = seed;
            }

            @Override
            public void accept(T t) {
                state = reducer.apply(state, t);
            }

            @Override
            public void combine(ReducingSink other) {
                state = combiner.apply(state, other.state);
            }
        }
        return new ReduceOp<T, U, ReducingSink>(StreamShape.REFERENCE) {
            @Override
            public ReducingSink makeSink() {
                return new ReducingSink();
            }
        };
    }

其中主要操作是生成了ReduceOp对象,并实现了 makeSink 方法。只是这里创建的是ReduceSink,相比起map操作中的Sink.ChainedReference,新增了combine方法。创建完成ReduceOp之后,需要执行具体的迭代处理操作,主要逻辑在AbstractPipeline类的wrapAndCopyInto中。wrap的目的是将链表上各个节点的Sink定义串联起来,形成一个新的单向链表。copyInfo则是启动集合的迭代操作,将每个集合元素传入头部Sink处理在迭代完毕之后,从Sink中取出最终结果并返回。代码如下所示:

final <P_IN> Sink<P_IN> wrapSink(Sink<E_OUT> sink) {
        Objects.requireNonNull(sink);

        for(@SuppressWarnings("rawtypes") AbstractPipeline p=AbstractPipeline.this; p.depth > 0; 
               p=p.previousStage) {
            sink = p.opWrapSink(p.previousStage.combinedFlags, sink);
        }
        return (Sink<P_IN>) sink;
 }

 @Override
 final <P_IN> void copyInto(Sink<P_IN> wrappedSink, Spliterator<P_IN> spliterator) {
        Objects.requireNonNull(wrappedSink);

        if (!StreamOpFlag.SHORT_CIRCUIT.isKnown(getStreamAndOpFlags())) {
            wrappedSink.begin(spliterator.getExactSizeIfKnown());
            spliterator.forEachRemaining(wrappedSink);
            wrappedSink.end();
        }
        else {
            copyIntoWithCancel(wrappedSink, spliterator);
        }
    }

可以看到wrapSink方法是将ReduceOp中创建的ReduceSink(通过makeSink方法)传入,然后沿着链表从后向前调用每个节点的opWrapSink方法,最终将所有链表操作串联起来。而copyInto操作则是如上所述,先调用begin方法,然后进行集合元素迭代,最后调用end。结合实例中的reduce参数seed为new ArrayList(), reducer为(left, right) -> {left.add(right);return left;}, combiner为 (left, right) -> { left.addAll(right); return left; })。在执行copyInto的时候,通过调用begin方法将seed复制给state,然后每次遍历,就将state和集合元素一起给到reducer并将结果重新赋值给state,而reducer的定义也是将元素right放入到left的集合,并返回left。可以看到在串行操作中只用到了begin和accept方法,感兴趣的读者可以将combiner设置成非空任意值,在串行时也是正常运行的。

那么示例中的collect是如何实现的呢?Collectors.toList()的源码如下所示:

public static <T> Collector<T, ?, List<T>> toList() {
        return new CollectorImpl<>((Supplier<List<T>>) ArrayList::new, List::add,
                                   (left, right) -> { left.addAll(right); return left; },
                                   CH_ID);
}

public interface Collector<T, A, R> {
    Supplier<A> supplier();
    BiConsumer<A, T> accumulator();
    BinaryOperator<A> combiner();
    Function<A, R> finisher();
    Set<Characteristics> characteristics();
}

Collectors.toList()是生成了一个Collector接口的实现,并且覆写了supplier、accumulator和combiner。接下来看下collect方法的具体实现,如下所示:

//.collect(Collectors.toList())
public static <T, I> TerminalOp<T, I>
    makeRef(Collector<? super T, I, ?> collector) {
        Supplier<I> supplier = Objects.requireNonNull(collector).supplier();
        BiConsumer<I, ? super T> accumulator = collector.accumulator();
        BinaryOperator<I> combiner = collector.combiner();
        class ReducingSink extends Box<I>
                implements AccumulatingSink<T, I, ReducingSink> {
            @Override
            public void begin(long size) {
                state = supplier.get();
            }

            @Override
            public void accept(T t) {
                accumulator.accept(state, t);
            }

            @Override
            public void combine(ReducingSink other) {
                state = combiner.apply(state, other.state);
            }
        }
        return new ReduceOp<T, I, ReducingSink>(StreamShape.REFERENCE) {
            @Override
            public ReducingSink makeSink() {
                return new ReducingSink();
            }

            @Override
            public int getOpFlags() {
                return collector.characteristics().contains(Collector.Characteristics.UNORDERED)
                       ? StreamOpFlag.NOT_ORDERED
                       : 0;
            }
        };
    }

//.collect(ArrayList::new, 
               List::add, 
             (left, right) -> left.addAll(right))
public static <T, R> TerminalOp<T, R>
    makeRef(Supplier<R> seedFactory,
            BiConsumer<R, ? super T> accumulator,
            BiConsumer<R,R> reducer) {
        Objects.requireNonNull(seedFactory);
        Objects.requireNonNull(accumulator);
        Objects.requireNonNull(reducer);
        class ReducingSink extends Box<R>
                implements AccumulatingSink<T, R, ReducingSink> {
            @Override
            public void begin(long size) {
                state = seedFactory.get();
            }

            @Override
            public void accept(T t) {
                accumulator.accept(state, t);
            }

            @Override
            public void combine(ReducingSink other) {
                reducer.accept(state, other.state);
            }
        }
        return new ReduceOp<T, R, ReducingSink>(StreamShape.REFERENCE) {
            @Override
            public ReducingSink makeSink() {
                return new ReducingSink();
            }
        };
    }

两个collect方法的区别在于前者的形参是一个Collector接口,而后者的形参是三个功能性接口(Functional Interface)。结合Collector接口的定义,就是将三个功能性接口合并到了Collector接口中。当然需要注意的是,Collector接口中第三个是BiFunction,而后者的collect方法的第三个参数是BiConsumer,因此不需要将执行结果赋值给state。那么这三个功能性接口分别起到的作用是什么呢,更具体地说他们是什么时候调用呢?其实在封装后的Collector中,一共有四个功能性接口,他们的功能分别如下:

接口名称 作用
supplier 提供一个种子值(seed),作为起始元素
accumulator 迭代集合元素,和seed进行处理,并将处理结果复制给seed
combiner 只在并行处理的时候有效。对每个segment返回的seed进行合并处理
finisher 生成最终结果的时候,可以进一步处理,得到预期想要的结果

笔者猜测,最早出现的是reduce方法,可能因为某种原因出现了collect方法和Collector,Collector将collect的形式参数进行了封装,并提供了助于toList、toMap等实现。至于从reduce到collect的原因,笔者怀疑可能有两点:

  1. reduce方法的第一个参数是泛型类型,而不是Functional interface,可能破坏了风格的统一,影响了封装并对外提供强大的功能。
  2. 可能其他语言用的是collect

关于串行无状态操作就讲解完毕了。笔者不禁反思,如果这个功能由笔者来实现,会呈现出什么样的效果?笔者稍加思索之后,可能首先想到的是如下的形式。对比一下,也许可以更多地领略设计的奥妙。

public class Stream<T> {
    private List<T> list;

    public Stream(List<T> list) {
        this.list = list;
    }

    Stream<T> filter(Predicate<T> predicate) {
        for (Iterator<T> iterator = list.iterator(); iterator.hasNext();) {
            if (!predicate.test(iterator.next())) {
                iterator.remove();
            }
        }
        return this;
    }

    <R> Stream<R> map(Function<T, R> function) {
        List<R> newList = new ArrayList<>();
        for (Iterator<T> iterator = list.iterator(); iterator.hasNext();) {
            newList.add(function.apply(iterator.next()));
        }
        return new Stream<>(newList);
    }

    public long count() {
        return list.size();
    }

    void foreach(Consumer<T> consumer) {
        for (Iterator<T> iterator = list.iterator(); iterator.hasNext();) {
            consumer.accept(iterator.next());
        }
    }
}

从效率上将,如果有N个操作,Java实现只需遍历一次,而笔者实现则需要遍历N次,时间复杂度是N* N,因此Java实现性能更好。第二个较大的问题,如果按照笔者实现,每个阶段要操作完所有元素,并行处理效率很低,每一个阶段都要等上一个阶段处理完毕,而针对并发有状态更是无从谈起。笔者认为这个可能是设计的最主要原因。按照链式的设计思路,在串行的基础上如何实现并行呢,只需在TerminalOp类中新增一个方法就行,再次体现了设计良好的扩展性。主要逻辑在AbstractTask中的compute和onCompletion方法中。

@Override
public void compute() {
   Spliterator<P_IN> rs = spliterator, ls; // right, left spliterators
   long sizeEstimate = rs.estimateSize();
   long sizeThreshold = getTargetSize(sizeEstimate);
   boolean forkRight = false;
   @SuppressWarnings("unchecked") K task = (K) this;
   while (sizeEstimate > sizeThreshold && (ls = rs.trySplit()) != null) {
      K leftChild, rightChild, taskToFork;
      task.leftChild  = leftChild = task.makeChild(ls);
      task.rightChild = rightChild = task.makeChild(rs);
      task.setPendingCount(1);
      if (forkRight) {
         forkRight = false;
         rs = ls;
         task = leftChild;
         taskToFork = rightChild;
      } else {
         forkRight = true;
         task = rightChild;
         taskToFork = leftChild;
      }
      taskToFork.fork();
      sizeEstimate = rs.estimateSize();
   }
   task.setLocalResult(task.doLeaf());
   task.tryComplete();
}

public void onCompletion(CountedCompleter<?> caller) {
   if (!isLeaf()) {
      S leftResult = leftChild.getLocalResult();
      leftResult.combine(rightChild.getLocalResult());
      setLocalResult(leftResult);
   }
            // GC spliterator, left and right child
   super.onCompletion(caller);
}

这个逻辑相对而言比较清晰,就是不断把任务进行划分,分成leftchild和rightChild。到了最小分区的时候,才进行处理。处理完毕,会调用onCompletion方法(因为AbstractTask继承的是CountedCompleter),在会调用onCompletion方法中调用combine方法将处理结果进行合并(在这里可以看到只有在并行处理的时候才会调用combine方法),这是典型的Map/Reduce思路。

  
如果对文章有任何不同意见,有两种办法: