Java 8 flatMap()

The flatMap() method of the java.util.stream.Stream interface applies a function to each element of the source stream. That function returns a stream for each element, and flatMap() concatenates (flattens) all of these streams into a single new stream.
Typically, flatMap() is used when the source stream consists of nested structures, such as a 2D array, a list of lists, or a list of objects that themselves contain other objects, and we want a single flat stream of the inner elements.
In this article, we will understand flatMap() with ample examples.

The Javadoc of flatMap() states:

Returns a stream consisting of the results of replacing each element of this stream with the contents of a mapped stream produced by applying the provided mapping function to each element.
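To make the Javadoc wording concrete, here is a minimal sketch in which every number n is replaced with the contents of a stream holding n copies of n. The class name ReplaceWithCopies and its expand() method are hypothetical names chosen for this illustration. Collections.nCopies() is the standard JDK method; it throws an exception for negative counts, so the sketch assumes non-negative inputs.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

class ReplaceWithCopies {

    // Replaces each element n with the contents of a stream of n copies of n.
    // Assumes n >= 0: Collections.nCopies() throws IllegalArgumentException otherwise.
    static List<Integer> expand(List<Integer> numbers) {
        return numbers.stream()
                // the mapper turns one element into one (possibly empty) Stream
                .flatMap(n -> Collections.nCopies(n, n).stream())
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(expand(Arrays.asList(1, 2, 3))); // [1, 2, 2, 3, 3, 3]
        System.out.println(expand(Arrays.asList(0, 2)));    // [2, 2]
    }
}
```

Note how the element 0 contributes nothing: its mapped stream is empty, so it simply disappears from the result.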

Its method signature is

<R> Stream<R> flatMap(Function<? super T, ? extends Stream<? extends R>> mapper);

This shows that flatMap() accepts a java.util.function.Function, called the mapper, which takes an element of the stream (of type T) and returns a Stream of R values, and that flatMap() itself returns a Stream<R>.

Function is a functional interface with a single abstract method, apply(), which accepts one argument and returns a value.
Leaving out its default methods (compose(), andThen()) and its static identity() method, the source of Function boils down to:

public interface Function<T, R> {

  R apply(T t);

}

So, the mapper passed to flatMap() must return a stream for every element. flatMap() concatenates these per-element streams, and the resulting single stream is what flatMap() returns.

Since Function is a functional interface, we can pass a lambda expression (or a method reference) as the argument to flatMap().
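As a rough illustration, the sketch below passes the same mapper to flatMap() in three forms: an anonymous class implementing Function.apply(), an equivalent lambda, and the method reference Collection::stream. The class MapperForms and its members are hypothetical names used only for this example, and it sticks to Java 8 APIs such as Arrays.asList().

```java
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class MapperForms {

    static final List<List<Integer>> NESTED =
            Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3), Arrays.asList(4, 5));

    // 1. Anonymous class implementing Function.apply()
    static final Function<List<Integer>, Stream<Integer>> ANONYMOUS =
            new Function<List<Integer>, Stream<Integer>>() {
                @Override
                public Stream<Integer> apply(List<Integer> inner) {
                    return inner.stream();
                }
            };

    // 2. Lambda expression with the same behaviour
    static final Function<List<Integer>, Stream<Integer>> LAMBDA = inner -> inner.stream();

    // Flattens NESTED using whichever mapper is supplied
    static List<Integer> flatten(Function<List<Integer>, Stream<Integer>> mapper) {
        return NESTED.stream().flatMap(mapper).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(flatten(ANONYMOUS)); // [1, 2, 3, 4, 5]
        System.out.println(flatten(LAMBDA));    // [1, 2, 3, 4, 5]
        // 3. Method reference passed directly
        System.out.println(NESTED.stream()
                .flatMap(Collection::stream)
                .collect(Collectors.toList())); // [1, 2, 3, 4, 5]
    }
}
```

All three forms produce the same flat list; the lambda and the method reference are just shorter ways of writing the anonymous class.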
Java flatMap() example
Below is an example of using flatMap() to convert a list of lists of integers into a single list of integers.
A stream whose elements are lists of integers is a stream of nested objects; we will use flatMap() to turn it into a simpler, flat stream of integers.

import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// create lists of numbers
List<Integer> list1 = List.of(1, 2, 3);
List<Integer> list2 = List.of(4, 5, 6);
List<Integer> list3 = List.of(7, 8, 9);
// create a list of lists
List<List<Integer>> list = List.of(list1, list2, list3);
// convert to a flat stream: each inner list becomes a stream of its elements
Stream<Integer> flatStream = list
    .stream()
    .flatMap(Collection::stream);
// convert the stream to a list
List<Integer> result = flatStream
    .collect(Collectors.toList());
System.out.println(result); // [1, 2, 3, 4, 5, 6, 7, 8, 9]

In this example, we first create three lists of numbers using the List.of() method and then create a list that contains these lists as its elements. Note that List.of() was added in Java 9; on Java 8, Arrays.asList() can be used instead.

Then, we call the stream() method on this list of lists to obtain a stream whose elements are the inner lists.
Next, we call flatMap() on this stream, passing it the method reference Collection::stream, which is shorthand for the lambda l -> l.stream().
For each inner list, Collection.stream() returns a stream of that list's elements.
flatMap() then concatenates these per-list streams into a single stream of integers.
So, we have used flatMap() to convert a stream of lists of numbers into a stream of numbers.

Finally, we call the collect() method on this flattened stream to gather the results into a new list.

The resulting list contains all the integers from the original lists, in order, in one flat list.
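The same idea carries over to the 2D arrays mentioned at the start of this article. The sketch below is one way to do it with Java 8 APIs, using hypothetical helper names flattenBoxed() and flattenPrimitive(): Arrays.stream() turns the outer array into a stream of rows, and flatMap(Arrays::stream) turns each row into a stream of its elements. For a primitive int[][], the primitive specialization flatMapToInt() avoids boxing altogether.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

class FlattenArray {

    // Flattens a 2D array of boxed Integers into a single list.
    static List<Integer> flattenBoxed(Integer[][] grid) {
        return Arrays.stream(grid)            // Stream<Integer[]>, one element per row
                .flatMap(Arrays::stream)      // each row -> Stream<Integer>
                .collect(Collectors.toList());
    }

    // For a primitive int[][], flatMapToInt() keeps everything as int values.
    static int[] flattenPrimitive(int[][] grid) {
        return Arrays.stream(grid)            // Stream<int[]>
                .flatMapToInt(Arrays::stream) // each row -> IntStream
                .toArray();
    }

    public static void main(String[] args) {
        Integer[][] boxed = {{1, 2, 3}, {4, 5}, {6}};
        int[][] primitive = {{1, 2, 3}, {4, 5}, {6}};
        System.out.println(flattenBoxed(boxed));                          // [1, 2, 3, 4, 5, 6]
        System.out.println(Arrays.toString(flattenPrimitive(primitive))); // [1, 2, 3, 4, 5, 6]
    }
}
```

Rows may have different lengths (a "jagged" array); flatMap() does not care, since it just concatenates whatever each row's stream contains.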
flatMap() example with list of objects
Suppose we have a list of Employee objects, and each Employee object has a list of Project objects.
This list represents the projects that the employee has worked on.

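import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class Employee {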
  private String name;
  private List<Project> projects;
  
  public Employee(String name, List<Project> projects) {
    this.name = name;
    this.projects = projects;
  }

  public List<Project> getProjects() {
    return projects;
  }
}

class Project {
  private String name;
   
  public Project(String name) {
    this.name = name;
  }
}

List<Employee> employees = Arrays.asList(
    new Employee("A", Arrays.asList(new Project("Project A"), 
                        new Project("Project B"))),
    new Employee("B", Arrays.asList(new Project("Project B"), 
                        new Project("Project C"))),
    new Employee("C", Arrays.asList(new Project("Project A"),  
                        new Project("Project C")))
);

Now, we want to create a list of all the projects that all the employees have worked on.
We can use flatMap() to flatten the list of lists of projects into a single stream of projects as shown below

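// flatten each employee's list of projects into one stream
Stream<Project> projectStream = employees
    .stream()
    .flatMap(e -> e.getProjects().stream());
// 6 Project objects in total; each project name appears twice
List<Project> projects = projectStream
    .collect(Collectors.toList());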

In this example, we pass a lambda expression to flatMap(). It is called once for each employee and turns that employee's list of projects into a stream of projects; flatMap() then concatenates these streams into a single stream of projects.

Finally, we call the collect() method on this flattened stream to get the results into a new list of Project objects.
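Because several employees share projects, the flattened list above holds six Project objects, with each project name appearing twice. The sketch below assumes a getName() accessor on Project (the article's class does not define one) and uses minimal stand-in classes so it compiles on its own; ProjectNames and its methods are hypothetical. It maps the flattened projects to their names, removes repeats with distinct(), and also shows the nested-loop code that flatMap() replaces. Deduplication is done on the names because Project does not override equals(), so distinct() on the Project objects themselves would keep all six.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

// Minimal stand-ins for the article's classes; getName() on Project is an
// assumption added for this sketch.
class Project {
    private final String name;
    Project(String name) { this.name = name; }
    String getName() { return name; }
}

class Employee {
    private final String name;
    private final List<Project> projects;
    Employee(String name, List<Project> projects) { this.name = name; this.projects = projects; }
    List<Project> getProjects() { return projects; }
}

class ProjectNames {

    static List<Employee> sample() {
        return Arrays.asList(
            new Employee("A", Arrays.asList(new Project("Project A"), new Project("Project B"))),
            new Employee("B", Arrays.asList(new Project("Project B"), new Project("Project C"))),
            new Employee("C", Arrays.asList(new Project("Project A"), new Project("Project C"))));
    }

    // Every project name, duplicates included, in encounter order.
    static List<String> allNames(List<Employee> employees) {
        return employees.stream()
                .flatMap(e -> e.getProjects().stream())
                .map(Project::getName)
                .collect(Collectors.toList());
    }

    // Same as allNames(), with duplicate names removed (first occurrence kept).
    static List<String> distinctNames(List<Employee> employees) {
        return employees.stream()
                .flatMap(e -> e.getProjects().stream())
                .map(Project::getName)
                .distinct()   // String overrides equals(), so equal names collapse
                .collect(Collectors.toList());
    }

    // Loop equivalent of allNames(): the outer loop walks the employees and the
    // inner loop walks each employee's projects, mirroring flatMap().
    static List<String> allNamesWithLoops(List<Employee> employees) {
        List<String> names = new ArrayList<>();
        for (Employee e : employees) {
            for (Project p : e.getProjects()) {
                names.add(p.getName());
            }
        }
        return names;
    }

    public static void main(String[] args) {
        List<Employee> employees = sample();
        System.out.println(allNames(employees));
        // [Project A, Project B, Project B, Project C, Project A, Project C]
        System.out.println(distinctNames(employees));
        // [Project A, Project B, Project C]
        System.out.println(allNamesWithLoops(employees).equals(allNames(employees))); // true
    }
}
```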

flatMap() to split strings into words
Suppose we have a list of strings, and we want to split each string into a list of words, and then create a single stream of all the words using flatMap().
Below is an example

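import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

List<String> lines = Arrays.asList("A B C", "X Y Z", "T U V");

List<String> wordList = lines
    .stream()
    .flatMap(line -> Arrays.stream(line.split(" ")))
    .collect(Collectors.toList());
System.out.println(wordList);
// Output: [A, B, C, X, Y, Z, T, U, V]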

In this example, we create a list of strings using the Arrays.asList() method, where each string contains several space-separated words.
Then, we call stream() on this list to obtain a stream of its elements, followed by flatMap() on that stream.
The argument to flatMap() is a lambda expression that splits a string on spaces and returns a stream of the resulting words.
Each string thus becomes a stream of its words, and flatMap() concatenates these streams into a single stream of words.

Finally, we call the collect() method on this flattened stream to gather the results into a new list of strings.

The resulting list contains all the words from the original strings, in order, in a single flat list.
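Splitting on a single space, as above, produces empty strings when the input contains repeated spaces, and it does not split on tabs at all. A possibly more robust variant, sketched below under the assumption that any run of whitespace separates words, uses Pattern.splitAsStream() (available since Java 8) as the mapper. The class SplitWords and its words() method are hypothetical names for this example.

```java
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

class SplitWords {

    // One or more whitespace characters act as the separator.
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    static List<String> words(List<String> lines) {
        return lines.stream()
                .map(String::trim)                  // avoid an empty leading token
                .filter(line -> !line.isEmpty())    // skip blank lines entirely
                .flatMap(WHITESPACE::splitAsStream) // each line -> Stream<String>
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<String> lines = Arrays.asList("A B C", "  X   Y Z", "", "T\tU V ");
        System.out.println(words(lines)); // [A, B, C, X, Y, Z, T, U, V]
    }
}
```

The blank-line filter matters because splitAsStream() on an empty string yields a stream containing one empty string rather than an empty stream.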
Stream map() vs flatMap()
Below are some of the differences between Stream map() and flatMap() methods.

1. map() transforms each element in the stream into exactly one output element, whose type may differ from the input type, while flatMap() transforms each element into zero or more output elements (the contents of the stream returned by the mapper).

2. flatMap() can flatten a stream of collections or arrays into a single stream of the elements they contain, as we saw in the examples above.
map() cannot do this on its own: mapping each element to a collection produces a stream of collections, not a stream of their elements.

3. map() keeps the structure one-to-one and never flattens, while each flatMap() call removes exactly one level of nesting. Deeper structures, such as a list of lists of lists, therefore need one flatMap() call per level.

4. map() is commonly used to apply a function to each element in the stream, such as converting each element to a different type or performing a calculation based on the element's value.
flatMap() is commonly used to flatten a stream of collections or arrays into a single stream of the elements contained in them.
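Points 1 and 3 can be checked with a short sketch. The hypothetical class MapVsFlatMap below shows map() turning each string into exactly one Integer, flatMap() turning each string into zero or more Character values (the empty string contributes nothing), and two chained flatMap() calls flattening a list nested three levels deep.

```java
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;

class MapVsFlatMap {

    // Point 1: map() gives exactly one output per input; the type may change
    // (String -> Integer here).
    static List<Integer> lengths(List<String> words) {
        return words.stream()
                .map(String::length)
                .collect(Collectors.toList());
    }

    // Point 1: flatMap() gives zero or more outputs per input. Each string is
    // replaced by its characters, so "" contributes nothing.
    static List<Character> characters(List<String> words) {
        return words.stream()
                .flatMap(w -> w.chars().mapToObj(c -> (char) c))
                .collect(Collectors.toList());
    }

    // Point 3: each flatMap() removes one level of nesting, so three levels
    // (List<List<List<Integer>>>) need two calls to reach the integers.
    static List<Integer> flattenTwice(List<List<List<Integer>>> nested) {
        return nested.stream()                 // Stream<List<List<Integer>>>
                .flatMap(Collection::stream)   // Stream<List<Integer>>
                .flatMap(Collection::stream)   // Stream<Integer>
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<String> words = Arrays.asList("hi", "", "abc");
        System.out.println(lengths(words));    // [2, 0, 3]
        System.out.println(characters(words)); // [h, i, a, b, c]

        List<List<List<Integer>>> nested = Arrays.asList(
                Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
                Arrays.asList(Arrays.asList(4), Arrays.asList(5, 6)));
        System.out.println(flattenTwice(nested)); // [1, 2, 3, 4, 5, 6]
    }
}
```

Three input strings give three lengths with map(), but five characters with flatMap(), which is the "exactly one" versus "zero or more" difference in practice.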

Below is an example that shows the difference between map() and flatMap():

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

class Person {
    private String name;
    private List<String> gadgets;

    public Person(String name, List<String> gadgets) {
        this.name = name;
        this.gadgets = gadgets;
    }

    public String getName() {
        return name;
    }

    public List<String> getGadgets() {
        return gadgets;
    }
}

List<Person> people = Arrays.asList(
    new Person("A", Arrays.asList("iphone", "laptop", "smartwatch")),
    new Person("Bob", Arrays.asList("kindle", "ipad")),
    new Person("Charlie", Arrays.asList("tablet"))
);

// Using map(): each Person becomes exactly one List<String>
List<List<String>> gadgetListsMap = people
    .stream()
    .map(Person::getGadgets)
    .collect(Collectors.toList());
System.out.println(gadgetListsMap);
// Output: [[iphone, laptop, smartwatch], [kindle, ipad], [tablet]]

// Using flatMap(): each Person's list is flattened into its gadgets
List<String> allGadgets = people
    .stream()
    .flatMap(person -> person.getGadgets().stream())
    .collect(Collectors.toList());
System.out.println(allGadgets);
// Output: [iphone, laptop, smartwatch, kindle, ipad, tablet]

In this example, map() produces a stream of lists, where each list holds the gadgets of a single person, so collecting it yields a List<List<String>>.
flatMap(), on the other hand, produces a single flattened stream of all the gadgets owned by all the people in the list.
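One way to see how the two methods relate: flatMap() behaves like a map() to the nested collections followed by flattening each of them. The sketch below redeclares a minimal Person as a nested class so it compiles on its own (MapThenFlatten, direct() and twoSteps() are hypothetical names) and produces the same list both ways.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

class MapThenFlatten {

    // Minimal stand-in for the article's Person class
    static class Person {
        private final String name;
        private final List<String> gadgets;
        Person(String name, List<String> gadgets) { this.name = name; this.gadgets = gadgets; }
        List<String> getGadgets() { return gadgets; }
    }

    static final List<Person> PEOPLE = Arrays.asList(
            new Person("A", Arrays.asList("iphone", "laptop", "smartwatch")),
            new Person("Bob", Arrays.asList("kindle", "ipad")),
            new Person("Charlie", Arrays.asList("tablet")));

    // flatMap() in one step
    static List<String> direct() {
        return PEOPLE.stream()
                .flatMap(p -> p.getGadgets().stream())
                .collect(Collectors.toList());
    }

    // map() to the nested lists first, then flatMap() each list into its elements
    static List<String> twoSteps() {
        return PEOPLE.stream()
                .map(Person::getGadgets)   // Stream<List<String>>
                .flatMap(List::stream)     // Stream<String>
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(direct());   // [iphone, laptop, smartwatch, kindle, ipad, tablet]
        System.out.println(twoSteps()); // [iphone, laptop, smartwatch, kindle, ipad, tablet]
    }
}
```

The two-step form can be handy when the map() result is already available, while the one-step form avoids the intermediate stream of lists.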

Hope the article was useful.