blob: 61ec54b667c5728642a3f249fb450c831cff80d2 [file] [log] [blame]
Makoto Onukid1435462024-05-09 10:22:18 -07001#!/usr/bin/python3
2# Copyright (C) 2024 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
# This script converts a legacy test class (one using AndroidTestCase, TestCase
# or InstrumentationTestCase) to a modern-style test class, in a best-effort manner.
18#
19# Usage:
20# convert-androidtest.py TARGET-FILE [TARGET-FILE ...]
21#
22# Caveats:
23# - It adds all the extra imports, even if they're not needed.
24# - It won't sort imports.
25# - It also always adds getContext() and getTestContext().
26#
27
28import sys
29import fileinput
30import re
31import subprocess
32
# Report a message to the user on the console (standard error).
def log(msg):
    """Write `msg`, followed by a newline, to standard error."""
    sys.stderr.write(f"{msg}\n")
36
37
# Matches `extends AndroidTestCase` (or another similar base class), together
# with any whitespace that follows the base class name.
# The \b after the class name keeps an unrelated base whose name merely starts
# with one of these (e.g. `TestCaseBase`) from being partially matched, which
# would leave a dangling `Base` behind once the match is removed.
re_extends = re.compile(
    r''' \b extends \s+ (AndroidTestCase|TestCase|InstrumentationTestCase) \b \s* ''',
    re.S | re.X)
42
43
# Look into given files and return the files that have `re_extends`.
def find_target_files(files):
    """Return the subset of `files` whose content matches `re_extends`.

    Files that cannot be opened or read (missing, a directory, permission
    denied, undecodable content, ...) are reported via log() and skipped, so
    one bad path does not abort the whole run.

    Args:
        files: iterable of file paths to inspect.

    Returns:
        A list of the paths, in input order, that contain a legacy
        `extends <base class>` clause.
    """
    ret = []

    for file in files:
        try:
            with open(file, 'r') as f:
                data = f.read()
        except OSError as e:
            # OSError covers FileNotFoundError as well as IsADirectoryError,
            # PermissionError, etc.
            log(f'Failed to open file {file}: {e}')
            continue
        except UnicodeDecodeError as e:
            log(f'Failed to read file {file}: {e}')
            continue

        if re_extends.search(data):
            ret.append(file)

    return ret
60
61
def main(args):
    """Convert the given legacy test files in place, then show a diff of each.

    Only files accepted by find_target_files() are rewritten. Each rewritten
    file keeps its original content in a sibling `<file>.bak` backup, which is
    also what the final `diff -u` compares against.

    Args:
        args: list of target file paths (the command line minus argv[0]).

    Returns:
        The process exit status; always 0.
    """
    # Find the files that should be processed.
    files = find_target_files(args)

    if not files:
        log("No target files found.")
        return 0

    # Imports injected before the first `import` line of each file.
    extra_imports = """\
import android.content.Context;
import androidx.test.platform.app.InstrumentationRegistry;

import static junit.framework.TestCase.assertEquals;
import static junit.framework.TestCase.assertSame;
import static junit.framework.TestCase.assertNotSame;
import static junit.framework.TestCase.assertTrue;
import static junit.framework.TestCase.assertFalse;
import static junit.framework.TestCase.assertNull;
import static junit.framework.TestCase.assertNotNull;
import static junit.framework.TestCase.fail;

import org.junit.After;
import org.junit.Before;
import org.junit.runner.RunWith;
import org.junit.Test;

import androidx.test.ext.junit.runners.AndroidJUnit4;
"""

    # Patterns applied to every line; compiled once up front.
    re_override = re.compile(r'''@Override\b''')
    # A class declaration: optional annotations and modifiers, then `class X`.
    # Anchoring to the declaration shape keeps the word "class" in comments
    # (or `Foo.class`) from being mistaken for the declaration.
    re_class_decl = re.compile(
        r''' ^ (?P<indent> [ \t]* )
             (?: @[\w.]+ (?: \( [^)]* \) )? \s+ )*
             (?: (?: public | protected | private | abstract | final | static | strictfp ) \s+ )*
             class \s+ \w ''',
        re.X)
    re_import = re.compile(r'''^import\b''')
    re_test_method = re.compile(
        r''' ^ (?P<indent> \s* ) public \s* void \s* test''', re.X)
    # \b after the name so that e.g. `setUpFoo()` is not treated as setUp().
    re_fixture = re.compile(
        r''' ^ (?P<indent> \s+ ) (?: @Override \s+ )? (?:public|protected) \s+ void \s+
             (?P<method> setUp|tearDown ) \b ''',
        re.X)
    re_super_fixture = re.compile(r''' \b super \. (setUp|tearDown) \b ''', re.X)
    re_mcontext = re.compile(r'''\b mContext \b ''', re.X)

    # Process the files. With inplace=True, print() writes into the file
    # currently being read.
    with fileinput.input(files=files, inplace=True, backup='.bak') as f:
        import_seen = False
        carry_over = ''
        class_seen = False
        class_body_started = False
        member_indent = '    '

        for line in f:
            if f.isfirstline():
                log(f"Processing: {f.filename()}")
                # Reset the per-file state.
                import_seen = False
                carry_over = ''
                class_seen = False
                class_body_started = False
                member_indent = '    '

            line = line.rstrip('\n')

            # Carry an @Override line over to the next line, so that both are
            # handled together as one unit below.
            if re_override.search(line):
                carry_over = carry_over + line + '\n'
                continue

            if carry_over:
                line = carry_over + line
                carry_over = ''

            # Remove the base class from the class definition.
            line = re_extends.sub('', line)

            # Add a @RunWith before the first class declaration.
            if not class_seen:
                class_match = re_class_decl.search(line)
                if class_match:
                    class_seen = True
                    class_indent = class_match.group('indent')
                    # Members are assumed to be indented four spaces deeper
                    # than the class declaration.
                    member_indent = class_indent + '    '
                    print(f"{class_indent}@RunWith(AndroidJUnit4.class)")

            # Inject extra imports before the first import line.
            if not import_seen and re_import.search(line):
                import_seen = True
                print(extra_imports)

            # Add @Test to the test methods, aligned with the method itself.
            test_match = re_test_method.search(line)
            if test_match:
                print(f"{test_match.group('indent')}@Test")

            # Convert setUp/tearDown to @Before/@After.
            fixture_match = re_fixture.search(line)
            if fixture_match:
                # Decide from the matched method name, not from any `setUp`
                # that may appear elsewhere on the line.
                if fixture_match.group('method') == 'setUp':
                    annotation = '@Before'
                else:
                    annotation = '@After'
                print(f"{fixture_match.group('indent')}{annotation}")

                # Drop the carried-over @Override and make the method public.
                line = re.sub(r''' \s* @Override \s* \n ''', '', line, flags=re.X)
                line = line.replace('protected', 'public')

            # Remove the super setUp / tearDown call.
            if re_super_fixture.search(line):
                continue

            # Convert mContext to getContext().
            line = re_mcontext.sub('getContext()', line)

            # Print the processed line.
            print(line)

            # Add getContext() / getTestContext() at the beginning of the
            # class. Only a `{` at or after the class declaration counts, so
            # braces in earlier comments or annotations are ignored.
            if class_seen and not class_body_started and '{' in line:
                class_body_started = True
                statement_indent = member_indent + '    '
                helper_lines = [
                    f"{member_indent}private Context getContext() {{",
                    f"{statement_indent}return InstrumentationRegistry"
                    ".getInstrumentation().getTargetContext();",
                    f"{member_indent}}}",
                    "",
                    f"{member_indent}private Context getTestContext() {{",
                    f"{statement_indent}return InstrumentationRegistry"
                    ".getInstrumentation().getContext();",
                    f"{member_indent}}}",
                ]
                # Trailing '\n' plus print()'s own newline leaves a blank line
                # after the helpers.
                print('\n'.join(helper_lines) + '\n')

    # Run diff
    for file in files:
        subprocess.call(["diff", "-u", "--color=auto", f"{file}.bak", file])

    log(f'{len(files)} file(s) converted.')

    return 0
182
# Script entry point: run main() on the command-line arguments and use its
# return value as the process exit status.
if __name__ == '__main__':
    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)