[백준][JAVA] - 트리의 지름(1967)
문제 소개
🥇️ 문제 레벨 : 골드4
🔔 문제 유형 : 깊이 우선 탐색, 그래프
💬 풀이 언어 : JAVA
⏱️ 풀이 시간 : 60분
🖇️ 문제 링크 : 백준 문제 링크
📝 문제
트리(tree)는 사이클이 없는 무방향 그래프이다. 트리에서는 어떤 두 노드를 선택해도 둘 사이에 경로가 항상 하나만 존재하게 된다. 트리에서 어떤 두 노드를 선택해서 양쪽으로 쫙 당길 때, 가장 길게 늘어나는 경우가 있을 것이다. 이럴 때 트리의 모든 노드들은 이 두 노드를 지름의 끝 점으로 하는 원 안에 들어가게 된다.
이런 두 노드 사이의 경로의 길이를 트리의 지름이라고 한다. 정확히 정의하자면 트리에 존재하는 모든 경로들 중에서 가장 긴 것의 길이를 말한다.
입력으로 루트가 있는 트리를 가중치가 있는 간선들로 줄 때, 트리의 지름을 구해서 출력하는 프로그램을 작성하시오. 아래와 같은 트리가 주어진다면 트리의 지름은 45가 된다.
트리의 노드는 1부터 n까지 번호가 매겨져 있다.
🤔 문제 분석
첫 줄을 제외하고, N-1
까지 입력되는 값을 보면 순서대로 부모노드 자식노드 가중치
가 들어오게 된다.
12
1 2 3
1 3 2
2 4 5
3 5 11
3 6 9
4 7 1
4 8 7
5 9 15
5 10 4
6 11 6
6 12 10
이와 비슷한 형태의 문제는 대부분 그래프로 풀이가 가능하다.
그래프 입력 받기
우선 Node
클래스를 생성한 뒤, 생성자를 통해 연결된 노드의 번호와 가중치를 저장할 수 있게 구성해준다.
private static class Node {
int linked, point;
public Node(int linked, int point) {
this.linked = linked;
this.point = point;
}
}
다음으로 그래프를 생성해 부모와 자식, 자식과 부모를 이어준다.
public class Main {
static int N;
// 그래프 리스트 배열
static ArrayList<Node>[] lists;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st;
N = Integer.parseInt(br.readLine());
lists = new ArrayList[N + 1];
visited = new boolean[N + 1];
// 그래프 배열 초기화
for (int i = 1; i <= N; i++) {
lists[i] = new ArrayList<>();
}
for (int i = 1; i < N; i++) {
st = new StringTokenizer(br.readLine());
int parent = Integer.parseInt(st.nextToken());
int child = Integer.parseInt(st.nextToken());
int point = Integer.parseInt(st.nextToken());
// 부모 리스트에 자식 노드 연결
lists[parent].add(new Node(child, point));
// 자식 리스트에 부모 노드 연결
lists[child].add(new Node(parent, point));
}
...
}
}
DFS를 통한 탐색
이제 시작 지점을 정하고, 최대값을 탐색하면 된다.
주어진 예시대로 첫 시작 지점을 1번 부터 시작했다고 가정하면, 아래와 같이 반복하게 된다.
1번 노드부터 시작해서 가장 끝 노드까지 탐색하면서 누적합(sum)이 가장 큰 값을 가진 노드가 리프노드(시작점)이 된다.
private static void dfs(int nodeNum, int sum) {
// 현재 노드 방문 처리
visited[nodeNum] = true;
for (Node node : lists[nodeNum]) {
// 방문하지 않은 자식 노드 탐색
if (!visited[node.linked]) {
// 재귀를 통해 가장 마지막 노드까지 탐색
dfs(node.linked, sum + node.point);
}
}
// 끝 노드에 방문해야 진행되며, 기존 누적합이 현재 노드까지의 누적합보다 작을 경우
if (result < sum) {
// 누적합 최대값 최신화
result = sum;
// 시작할 리프 노드로 지정
leafNode = nodeNum;
}
}
😎 최종 코드
가장 마지막 풀이가 최적화된 코드입니다.
풀이1
모든 노드 중 부모 노드가 1개인 노드만 탐색해 가능성을 확인하는 코드
예제 입력 기준으로 총 6개의 노드(7, 8, 9, 10, 11, 12)에서 시작
메모리 : 124384 KB
시간 : 1860 ms
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.StringTokenizer;
public class Main {
static int N;
static ArrayList<Node>[] lists;
static boolean[] visited;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st;
N = Integer.parseInt(br.readLine());
lists = new ArrayList[N + 1];
for (int i = 0; i <= N; i++) {
lists[i] = new ArrayList<>();
}
for (int i = 1; i < N; i++) {
st = new StringTokenizer(br.readLine());
int parent = Integer.parseInt(st.nextToken());
int child = Integer.parseInt(st.nextToken());
int point = Integer.parseInt(st.nextToken());
lists[child].add(new Node(parent, point));
lists[parent].add(new Node(child, point));
}
int max = Integer.MIN_VALUE;
for (int i = 1; i <= N; i++) {
if (lists[i].size() == 1) {
visited = new boolean[N + 1];
max = Math.max(dfs(i), max);
}
}
System.out.println(max == Integer.MIN_VALUE ? 0 : max);
}
private static int dfs(int nodeNum) {
visited[nodeNum] = true;
int maxSum = 0;
for (Node node : lists[nodeNum]) {
if (!visited[node.linked]) {
int s = dfs(node.linked) + node.point;
maxSum = Math.max(maxSum, s);
}
}
return maxSum;
}
private static class Node {
int linked, point;
public Node(int linked, int point) {
this.linked = linked;
this.point = point;
}
}
}
풀이2
1번 노드에서 탐색을 시작해, 누적합이 가장 큰 노드를 찾아 최대값을 탐색하는 코드
예제 입력 기준으로 총 2개의 노드(1, 9)에서 시작
메모리 : 20840 KB
시간 : 188 ms
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.StringTokenizer;
public class Main {
static int N, leafNode, result;
static ArrayList<Node>[] lists;
static boolean[] visited;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st;
N = Integer.parseInt(br.readLine());
lists = new ArrayList[N + 1];
for (int i = 0; i <= N; i++) {
lists[i] = new ArrayList<>();
}
for (int i = 1; i < N; i++) {
st = new StringTokenizer(br.readLine());
int parent = Integer.parseInt(st.nextToken());
int child = Integer.parseInt(st.nextToken());
int point = Integer.parseInt(st.nextToken());
lists[child].add(new Node(parent, point));
lists[parent].add(new Node(child, point));
}
visited = new boolean[N + 1];
dfs(1, 0);
for (int i = 0; i <= N; i++) {
visited[i] = false;
}
dfs(leafNode, 0);
System.out.println(result);
br.close();
}
private static void dfs(int nodeNum, int sum) {
visited[nodeNum] = true;
for (Node node : lists[nodeNum]) {
if (!visited[node.linked]) {
dfs(node.linked, sum + node.point);
}
}
if (result < sum) {
result = sum;
leafNode = nodeNum;
}
}
private static class Node {
int linked, point;
public Node(int linked, int point) {
this.linked = linked;
this.point = point;
}
}
}
풀이3
풀이 2번과 동일하지만 Node
를 사용하지 않고, 배열을 이용한 풀이
메모리 : 20624 KB
시간 : 176 ms
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.StringTokenizer;
public class Main {
static int N, leafNode, result;
static ArrayList<int[]>[] lists;
static boolean[] visited;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st;
N = Integer.parseInt(br.readLine());
lists = new ArrayList[N + 1];
for (int i = 0; i <= N; i++) {
lists[i] = new ArrayList<>();
}
for (int i = 1; i < N; i++) {
st = new StringTokenizer(br.readLine());
int parent = Integer.parseInt(st.nextToken());
int child = Integer.parseInt(st.nextToken());
int point = Integer.parseInt(st.nextToken());
lists[child].add(new int[]{parent, point});
lists[parent].add(new int[]{child, point});
}
visited = new boolean[N + 1];
dfs(1, 0);
visited = new boolean[N + 1];
dfs(leafNode, 0);
System.out.println(result);
br.close();
}
private static void dfs(int nodeNum, int sum) {
visited[nodeNum] = true;
for (int[] node : lists[nodeNum]) {
if (!visited[node[0]]) {
dfs(node[0], sum + node[1]);
}
}
if (result < sum) {
result = sum;
leafNode = nodeNum;
}
}
}
댓글남기기