[백준][JAVA] - 트리의 지름(1967)

문제 소개

🥇️ 문제 레벨 : 골드4

🔔 문제 유형 : 깊이 우선 탐색, 그래프

💬 풀이 언어 : JAVA

⏱️ 풀이 시간 : 60분

🖇️ 문제 링크 : 백준 문제 링크


📝 문제

트리(tree)는 사이클이 없는 무방향 그래프이다. 트리에서는 어떤 두 노드를 선택해도 둘 사이에 경로가 항상 하나만 존재하게 된다. 트리에서 어떤 두 노드를 선택해서 양쪽으로 쫙 당길 때, 가장 길게 늘어나는 경우가 있을 것이다. 이럴 때 트리의 모든 노드들은 이 두 노드를 지름의 끝 점으로 하는 원 안에 들어가게 된다.

img1

이런 두 노드 사이의 경로의 길이를 트리의 지름이라고 한다. 정확히 정의하자면 트리에 존재하는 모든 경로들 중에서 가장 긴 것의 길이를 말한다.

입력으로 루트가 있는 트리를 가중치가 있는 간선들로 줄 때, 트리의 지름을 구해서 출력하는 프로그램을 작성하시오. 아래와 같은 트리가 주어진다면 트리의 지름은 45가 된다.

img2

트리의 노드는 1부터 n까지 번호가 매겨져 있다.

🤔 문제 분석

첫 줄을 제외하고, N-1까지 입력되는 값을 보면 순서대로 부모노드 자식노드 가중치가 들어오게 된다.

12
1 2 3
1 3 2
2 4 5
3 5 11
3 6 9
4 7 1
4 8 7
5 9 15
5 10 4
6 11 6
6 12 10

이와 비슷한 형태의 문제는 대부분 그래프로 풀이가 가능하다.

그래프 입력 받기

우선 Node 클래스를 생성한 뒤, 생성자를 통해 연결된 노드의 번호와 가중치를 저장할 수 있게 구성해준다.

private static class Node {
    int linked, point;

    public Node(int linked, int point) {
        this.linked = linked;
        this.point = point;
    }
}

다음으로 그래프를 생성해 부모와 자식, 자식과 부모를 이어준다.

public class Main {
    static int N;
    
    // 그래프 리스트 배열
    static ArrayList<Node>[] lists;
    
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st;
        N = Integer.parseInt(br.readLine());

        lists = new ArrayList[N + 1];
        visited = new boolean[N + 1];

        // 그래프 배열 초기화
        for (int i = 1; i <= N; i++) {
            lists[i] = new ArrayList<>();
        }

        for (int i = 1; i < N; i++) {
            st = new StringTokenizer(br.readLine());
            int parent = Integer.parseInt(st.nextToken());
            int child = Integer.parseInt(st.nextToken());
            int point = Integer.parseInt(st.nextToken());
            
            // 부모 리스트에 자식 노드 연결
            lists[parent].add(new Node(child, point));
            // 자식 리스트에 부모 노드 연결
            lists[child].add(new Node(parent, point));
        }
        
        ...
    }
}

DFS를 통한 탐색

이제 시작 지점을 정하고, 최대값을 탐색하면 된다.

주어진 예시대로 첫 시작 지점을 1번 부터 시작했다고 가정하면, 아래와 같이 반복하게 된다.

img1

1번 노드부터 시작해서 가장 끝 노드까지 탐색하면서 누적합(sum)이 가장 큰 값을 가진 노드가 리프노드(시작점)이 된다.

private static void dfs(int nodeNum, int sum) {
    // 현재 노드 방문 처리
    visited[nodeNum] = true;

    for (Node node : lists[nodeNum]) {
        // 방문하지 않은 자식 노드 탐색
        if (!visited[node.linked]) {
            // 재귀를 통해 가장 마지막 노드까지 탐색
            dfs(node.linked, sum + node.point);
        }
    }
    
    // 끝 노드에 방문해야 진행되며, 기존 누적합이 현재 노드까지의 누적합보다 작을 경우
    if (result < sum) {
        // 누적합 최대값 최신화
        result = sum;
        // 시작할 리프 노드로 지정
        leafNode = nodeNum;
    }

}

😎 최종 코드

가장 마지막 풀이가 최적화된 코드입니다.

풀이1

모든 노드 중 부모 노드가 1개인 노드만 탐색해 가능성을 확인하는 코드

예제 입력 기준으로 총 6개의 노드(7, 8, 9, 10, 11, 12)에서 시작

메모리 : 124384 KB
시간 : 1860 ms
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.StringTokenizer;

public class Main {
    static int N;
    static ArrayList<Node>[] lists;
    static boolean[] visited;
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st;
        N = Integer.parseInt(br.readLine());

        lists = new ArrayList[N + 1];
        for (int i = 0; i <= N; i++) {
            lists[i] = new ArrayList<>();
        }

        for (int i = 1; i < N; i++) {
            st = new StringTokenizer(br.readLine());
            int parent = Integer.parseInt(st.nextToken());
            int child = Integer.parseInt(st.nextToken());
            int point = Integer.parseInt(st.nextToken());

            lists[child].add(new Node(parent, point));
            lists[parent].add(new Node(child, point));
        }

        int max = Integer.MIN_VALUE;

        for (int i = 1; i <= N; i++) {
            if (lists[i].size() == 1) {
                visited = new boolean[N + 1];
                max = Math.max(dfs(i), max);
            }

        }

        System.out.println(max == Integer.MIN_VALUE ? 0 : max);
    }

    private static int dfs(int nodeNum) {
        visited[nodeNum] = true;
        int maxSum = 0;

        for (Node node : lists[nodeNum]) {
            if (!visited[node.linked]) {
                int s = dfs(node.linked) + node.point;
                maxSum = Math.max(maxSum, s);
            }
        }

        return maxSum;
    }

    private static class Node {
        int linked, point;

        public Node(int linked, int point) {
            this.linked = linked;
            this.point = point;
        }
    }
}

풀이2

1번 노드에서 탐색을 시작해, 누적합이 가장 큰 노드를 찾아 최대값을 탐색하는 코드

예제 입력 기준으로 총 2개의 노드(1, 9)에서 시작

메모리 : 20840 KB
시간 : 188 ms
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.StringTokenizer;

public class Main {
    static int N, leafNode, result;
    static ArrayList<Node>[] lists;
    static boolean[] visited;
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st;
        N = Integer.parseInt(br.readLine());

        lists = new ArrayList[N + 1];
        for (int i = 0; i <= N; i++) {
            lists[i] = new ArrayList<>();
        }

        for (int i = 1; i < N; i++) {
            st = new StringTokenizer(br.readLine());
            int parent = Integer.parseInt(st.nextToken());
            int child = Integer.parseInt(st.nextToken());
            int point = Integer.parseInt(st.nextToken());

            lists[child].add(new Node(parent, point));
            lists[parent].add(new Node(child, point));
        }

        visited = new boolean[N + 1];

        dfs(1, 0);

        for (int i = 0; i <= N; i++) {
            visited[i] = false;
        }

        dfs(leafNode, 0);

        System.out.println(result);

        br.close();
    }

    private static void dfs(int nodeNum, int sum) {
        visited[nodeNum] = true;

        for (Node node : lists[nodeNum]) {
            if (!visited[node.linked]) {
                dfs(node.linked, sum + node.point);
            }
        }

        if (result < sum) {
            result = sum;
            leafNode = nodeNum;
        }

    }

    private static class Node {
        int linked, point;

        public Node(int linked, int point) {
            this.linked = linked;
            this.point = point;
        }
    }
}

풀이3

풀이 2번과 동일하지만 Node를 사용하지 않고, 배열을 이용한 풀이

메모리 : 20624 KB
시간 : 176 ms
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.StringTokenizer;

public class Main {
    static int N, leafNode, result;
    static ArrayList<int[]>[] lists;
    static boolean[] visited;
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st;
        N = Integer.parseInt(br.readLine());

        lists = new ArrayList[N + 1];
        for (int i = 0; i <= N; i++) {
            lists[i] = new ArrayList<>();
        }

        for (int i = 1; i < N; i++) {
            st = new StringTokenizer(br.readLine());
            int parent = Integer.parseInt(st.nextToken());
            int child = Integer.parseInt(st.nextToken());
            int point = Integer.parseInt(st.nextToken());

            lists[child].add(new int[]{parent, point});
            lists[parent].add(new int[]{child, point});
        }

        visited = new boolean[N + 1];

        dfs(1, 0);

        visited = new boolean[N + 1];

        dfs(leafNode, 0);

        System.out.println(result);

        br.close();
    }

    private static void dfs(int nodeNum, int sum) {
        visited[nodeNum] = true;

        for (int[] node : lists[nodeNum]) {
            if (!visited[node[0]]) {
                dfs(node[0], sum + node[1]);
            }
        }

        if (result < sum) {
            result = sum;
            leafNode = nodeNum;
        }

    }
}

댓글남기기