diff --git a/build.sla b/build.sla index 6d22eae..38d6f61 100644 --- a/build.sla +++ b/build.sla @@ -22,7 +22,7 @@ test () { # Code quality linting (pylint) cd $ROOTDIR - cd src && pylint --reports=n -dR -d star-args -d no-member *.py || true + cd src && pylint --reports=n -dR -d subprocess-popen-preexec-fn -d invalid-name -d star-args -d no-member *.py || true cd $ROOTDIR } diff --git a/examples/auth/auth.json b/examples/auth/auth.json index 64fb752..a8a0751 100644 --- a/examples/auth/auth.json +++ b/examples/auth/auth.json @@ -8,7 +8,7 @@ { "name": "do_nothing", "title": "Test form", - "description": "You should only see this if you've entered the correct password", + "description": "All users that logged in should be able to see this.", "submit_title": "Do nothing", "script": "job_do_nothing.sh", "fields": [ diff --git a/examples/simple/job_import.sh b/examples/simple/job_import.sh index e589470..c50bdce 100755 --- a/examples/simple/job_import.sh +++ b/examples/simple/job_import.sh @@ -6,5 +6,9 @@ MYSQL="mysql --defaults-file=$MYSQL_DEFAULTS_FILE" echo "This is what would be executed if this wasn't a fake script:" echo -echo "echo 'DROP DATABASE $target_db' | $MYSQL" -echo "$MYSQL ${target_db} < ${sql_file}" +echo " echo 'DROP DATABASE $target_db' | $MYSQL" +echo " $MYSQL ${target_db} < ${sql_file}" + +echo +echo "The uploaded file was $(stat --printf="%s" $sql_file) bytes" +echo "The (binary) md5 hash of the uploaded file is: $(md5sum -b $sql_file | cut -d " " -f1)" diff --git a/src/daemon.py b/src/daemon.py index e2264e6..747ac03 100644 --- a/src/daemon.py +++ b/src/daemon.py @@ -15,7 +15,6 @@ class DaemonError(Exception): """ Default error for Daemon class. """ - pass class Daemon(object): # pragma: no cover @@ -97,7 +96,8 @@ class Daemon(object): # pragma: no cover return None try: - pid = int(file(self.pid_file, 'r').read().strip()) + with open(self.pid_file, "r") as fh: + pid = int(fh.read().strip()) except ValueError: return None @@ -137,9 +137,8 @@ class Daemon(object): # pragma: no cover pid = os.fork() if pid > 0: self.log.info("PID = %s", pid) - pidfile = file(self.pid_file, 'w') - pidfile.write(str(pid)) - pidfile.close() + with open(self.pid_file, "w") as fh: + fh.write(str(pid)) sys.exit(0) # End parent atexit.register(self._cleanup) diff --git a/src/formconfig.py b/src/formconfig.py index cc9f3a6..a3c26e5 100644 --- a/src/formconfig.py +++ b/src/formconfig.py @@ -13,7 +13,6 @@ class FormConfigError(Exception): """ Default error for FormConfig errors """ - pass class FormConfig(object): diff --git a/src/formdefinition.py b/src/formdefinition.py index 7def906..4e0e7aa 100644 --- a/src/formdefinition.py +++ b/src/formdefinition.py @@ -10,8 +10,9 @@ import runscript class ValidationError(Exception): - """Default exception for Validation errors""" - pass + """ + Default exception for Validation errors + """ class FormDefinition(object): diff --git a/src/scriptform.py b/src/scriptform.py index 6c637ee..d37485b 100755 --- a/src/scriptform.py +++ b/src/scriptform.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ @@ -6,13 +6,12 @@ Main ScriptForm program """ import sys -import optparse +import argparse import os import json import logging -import thread +import threading import hashlib -import socket if hasattr(sys, 'dont_write_bytecode'): sys.dont_write_bytecode = True @@ -51,7 +50,8 @@ class ScriptForm(object): if self.cache and self.form_config_singleton is not None: return self.form_config_singleton - file_contents = file(self.config_file, 'r').read() + with open(self.config_file, "r") as fh: + file_contents = fh.read() try: config = json.loads(file_contents) except ValueError as err: @@ -67,7 +67,8 @@ class ScriptForm(object): if 'static_dir' in config: static_dir = config['static_dir'] if 'custom_css' in config: - custom_css = file(config['custom_css'], 'r').read() + with open(config["custom_css"], "r") as fh: + custom_css = fh.read() if 'users' in config: users = config['users'] for form in config['forms']: @@ -116,7 +117,10 @@ class ScriptForm(object): self.httpd.daemon_threads = True self.log.info("Listening on %s:%s", listen_addr, listen_port) self.running = True - self.httpd.serve_forever() + try: + self.httpd.serve_forever() + except KeyboardInterrupt: + pass self.running = False def shutdown(self): @@ -138,41 +142,60 @@ class ScriptForm(object): # We need to spawn a new thread in which the server is shut down, # because doing it from the main thread blocks, since the server is # waiting for connections.. - thread.start_new_thread(t_shutdown, (self, )) + thread = threading.Thread(target=t_shutdown, args=(self,)) + thread.start() def main(): # pragma: no cover """ main method """ - usage = [ - sys.argv[0] + " [option] (--start|--stop) ", - " " + sys.argv[0] + " --generate-pw", - ] - parser = optparse.OptionParser(version="%%VERSION%%") - parser.set_usage('\n'.join(usage)) - - parser.add_option("-g", "--generate-pw", dest="generate_pw", - action="store_true", default=False, - help="Generate password") - parser.add_option("-p", "--port", dest="port", action="store", type="int", - default=8081, help="Port to listen on (default=8081)") - parser.add_option("-f", "--foreground", dest="foreground", - action="store_true", default=False, - help="Run in foreground (debugging)") - parser.add_option("-r", "--reload", dest="reload", action="store_true", - default=False, - help="Reload form config on every request (DEV)") - parser.add_option("--pid-file", dest="pid_file", action="store", - default=None, help="Pid file") - parser.add_option("--log-file", dest="log_file", action="store", - default=None, help="Log file") - parser.add_option("--start", dest="action_start", action="store_true", - default=None, help="Start daemon") - parser.add_option("--stop", dest="action_stop", action="store_true", - default=None, help="Stop daemon") - - (options, args) = parser.parse_args() + parser = argparse.ArgumentParser(description='My Application.') + parser.add_argument('--version', + action='version', + version='%(prog)s %%VERSION%%') + parser.add_argument('-g', '--generate-pw', + action='store_true', + default=False, + help='Generate password') + parser.add_argument('-p', '--port', + metavar='PORT', + dest='port', + type=int, + default=8081, + help='Port to listen on (default=8081)') + parser.add_argument('-f', '--foreground', + dest='foreground', + action='store_true', + default=False, + help='Run in foreground (debugging)') + parser.add_argument('-r', '--reload', + dest='reload', + action='store_true', + default=False, + help='Reload form config on every request (DEV)') + parser.add_argument('--pid-file', + metavar='PATH', + dest='pid_file', + type=str, + default=None, + help='Pid file') + parser.add_argument('--log-file', + metavar='PATH', + dest='log_file', + type=str, + default=None, + help='Log file') + parser.add_argument('--stop', + dest='action_stop', + action='store_true', + default=None, + help='Stop daemon') + parser.add_argument(dest='config', + metavar="CONFIG_FILE", + help="Path to form definition config", + ) + options = parser.parse_args() if options.generate_pw: # Generate a password for use in the `users` section @@ -181,45 +204,27 @@ def main(): # pragma: no cover if plain_pw != getpass.getpass('Repeat password: '): sys.stderr.write("Passwords do not match.\n") sys.exit(1) - sys.stdout.write(hashlib.sha256(plain_pw).hexdigest() + '\n') + sha = hashlib.sha256(plain_pw.encode('utf8')).hexdigest() + sys.stdout.write("{}\n".format(sha)) sys.exit(0) else: - if not options.action_stop and len(args) < 1: - parser.error("Insufficient number of arguments") - if not options.action_stop and not options.action_start: - options.action_start = True - - # If a form configuration was specified, change to that dir so we can - # find the job scripts and such. - if args: - path = os.path.dirname(args[0]) - if path: - os.chdir(path) - args[0] = os.path.basename(args[0]) + # Switch to dir of form definition configuration + formconfig_path = os.path.realpath(options.config) + os.chdir(os.path.dirname(formconfig_path)) + # Initialize daemon so we can start or stop it daemon = Daemon(options.pid_file, options.log_file, foreground=options.foreground) - log = logging.getLogger('MAIN') - try: - if options.action_start: - cache = not options.reload - scriptform_instance = ScriptForm(args[0], cache=cache) - daemon.register_shutdown_callback(scriptform_instance.shutdown) - daemon.start() - scriptform_instance.run(listen_port=options.port) - elif options.action_stop: - daemon.stop() - sys.exit(0) - except socket.error as err: - log.exception(err) - sys.stderr.write("Cannot bind to port {0}: {1}\n".format( - options.port, - str(err) - )) - sys.exit(2) - except Exception as err: - log.exception(err) - raise + + if options.action_stop: + daemon.stop() + sys.exit(0) + else: + cache = not options.reload + scriptform_instance = ScriptForm(formconfig_path, cache=cache) + daemon.register_shutdown_callback(scriptform_instance.shutdown) + daemon.start() + scriptform_instance.run(listen_port=options.port) if __name__ == "__main__": # pragma: no cover diff --git a/src/webapp.py b/src/webapp.py index edf7c55..bb0e5ab 100644 --- a/src/webapp.py +++ b/src/webapp.py @@ -3,7 +3,7 @@ The webapp part of Scriptform, which takes care of serving requests and handling them. """ -import cgi +import html import logging import tempfile import os @@ -209,14 +209,13 @@ class ScriptFormWebApp(RequestHandler): # If a 'users' element was present in the form configuration file, the # user must be authenticated. if form_config.users: - auth_header = self.headers.getheader("Authorization") + auth_header = self.headers.get("Authorization") if auth_header is not None: # Validate the username and password - auth_unpw = auth_header.split(' ', 1)[1] - username, password = base64.decodestring(auth_unpw).split(":", - 1) - pw_hash = hashlib.sha256(password).hexdigest() - + auth_unpw = auth_header.split(' ', 1)[1].encode('utf-8') + username, password = \ + base64.b64decode(auth_unpw).decode('utf-8').split(":", 1) + pw_hash = hashlib.sha256(password.encode('utf-8')).hexdigest() if username in form_config.users and \ pw_hash == form_config.users[username]: # Valid username and password. Return the username. @@ -414,13 +413,12 @@ class ScriptFormWebApp(RequestHandler): if field.filename == '': continue tmp_fname = tempfile.mktemp(prefix="scriptform_") - tmp_file = file(tmp_fname, 'w') - while True: - buf = field.file.read(1024 * 16) - if not buf: - break - tmp_file.write(buf) - tmp_file.close() + with open(tmp_fname, "wb") as tmp_file: + while True: + buf = field.file.read(1024 * 16) + if not buf: + break + tmp_file.write(buf) field.file.close() tmp_files.append(tmp_fname) # For later cleanup @@ -461,11 +459,11 @@ class ScriptFormWebApp(RequestHandler): # Ignore everything if we're doing raw output, since it's the # scripts responsibility. if result['exitcode'] != 0: - stderr = cgi.escape(result['stderr'].decode('utf8')) + stderr = html.escape(result['stderr'].decode('utf8')) msg = u'{0}'.format(stderr) else: if form_def.output == 'escaped': - stdout = cgi.escape(result['stdout'].decode('utf8')) + stdout = html.escape(result['stdout'].decode('utf8')) msg = u'
{0}
'.format(stdout) else: # Non-escaped output (html, usually) @@ -509,7 +507,7 @@ class ScriptFormWebApp(RequestHandler): if not os.path.exists(path): raise HTTPError(404, "Not found") - static_file = file(path, 'r') self.send_response(200) self.end_headers() - self.wfile.write(static_file.read()) + with open(path, "rb") as static_file: + self.wfile.write(static_file.read()) diff --git a/src/webserver.py b/src/webserver.py index b322a35..76d38b7 100644 --- a/src/webserver.py +++ b/src/webserver.py @@ -2,10 +2,10 @@ Basic web server / framework. """ -import BaseHTTPServer +from socketserver import ThreadingMixIn +from http.server import HTTPServer, BaseHTTPRequestHandler +import urllib.parse import cgi -import urlparse -from SocketServer import ThreadingMixIn class HTTPError(Exception): @@ -14,6 +14,9 @@ class HTTPError(Exception): etc. They are caught by the 'framework' and sent to the client's browser. """ def __init__(self, status_code, msg, headers=None): + assert isinstance(status_code, int) + assert isinstance(msg, str) + if headers is None: headers = {} self.status_code = status_code @@ -22,14 +25,13 @@ class HTTPError(Exception): Exception.__init__(self, status_code, msg, headers) -class ThreadedHTTPServer(ThreadingMixIn, BaseHTTPServer.HTTPServer): +class ThreadedHTTPServer(ThreadingMixIn, HTTPServer): """ Base class for multithreaded HTTP servers. """ - pass -class RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler): +class RequestHandler(BaseHTTPRequestHandler): """ Basic web server request handler. Handles GET and POST requests. You should inherit from this class and implement h_ methods for handling requests. @@ -57,25 +59,13 @@ class RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler): environ={'REQUEST_METHOD': 'POST'}) self._call(self.path.strip('/'), params={'form_values': form_values}) - def do_OPTIONS(self): # pylint: disable=invalid-name - """ - Handle OPTIONS request and return CORS headers. - """ - self.send_response(200, 'ok') - self.send_header('Access-Control-Allow-Origin', '*') - self.send_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS') - self.send_header('Access-Control-Allow-Headers', 'X-Requested-With') - self.send_header('Access-Control-Allow-Headers', 'Content-Type, ' - 'Authorization') - self.end_headers() - def _parse(self, reqinfo): """ Parse information from a request. """ - url_comp = urlparse.urlsplit(reqinfo) + url_comp = urllib.parse.urlsplit(reqinfo) path = url_comp.path - query_vars = urlparse.parse_qs(url_comp.query) + query_vars = urllib.parse.parse_qs(url_comp.query) # Only return the first value of each query var. E.g. for # "?foo=1&foo=2" return '1'. var_values = dict([(k, v[0]) for k, v in query_vars.items()]) @@ -116,7 +106,7 @@ class RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler): self.send_header(header_k, header_v) self.end_headers() self.wfile.write("Error {0}: {1}".format(err.status_code, - err.msg)) + err.msg).encode('utf-8')) self.wfile.flush() return False except Exception as err: diff --git a/test/test.py b/test/test.py index 81f12c5..9c12c9f 100644 --- a/test/test.py +++ b/test/test.py @@ -1,14 +1,12 @@ import logging import sys import unittest -from StringIO import StringIO import json import os import copy -import thread +import threading import time import requests -import StringIO import re @@ -51,22 +49,24 @@ class FormConfigTestCase(unittest.TestCase): fc = sf.get_form_config() fd = fc.get_form_def('test_store') res = runscript.run_script(fd, {}, {}) - self.assertEquals(res['exitcode'], 33) - self.assertTrue('stdout' in res['stdout']) - self.assertTrue('stderr' in res['stderr']) + self.assertEqual(res['exitcode'], 33) + self.assertTrue(b'stdout' in res['stdout']) + self.assertTrue(b'stderr' in res['stderr']) def testCallbackRaw(self): """Test a callback that returns raw output""" sf = scriptform.ScriptForm('test_formconfig_callback.json') fc = sf.get_form_config() fd = fc.get_form_def('test_raw') - stdout = file('tmp_stdout', 'w+') # can't use StringIO - stderr = file('tmp_stderr', 'w+') + stdout = open('tmp_stdout', 'w+') # can't use StringIO + stderr = open('tmp_stderr', 'w+') exitcode = runscript.run_script(fd, {}, {}, stdout, stderr) stdout.seek(0) stderr.seek(0) self.assertTrue(exitcode == 33) self.assertTrue('stdout' in stdout.read()) + stdout.close() + stderr.close() def testCallbackMissingParams(self): """ @@ -115,7 +115,7 @@ class FormDefinitionTest(unittest.TestCase): form_values = {"val_string": "1234"} errors, values = fd.validate(form_values) self.assertNotIn('val_string', errors) - self.assertEquals(values['val_string'], "1234") + self.assertEqual(values['val_string'], "1234") def testValidateIntegerInvalid(self): fd = self.fc.get_form_def('test_val_integer') @@ -143,7 +143,7 @@ class FormDefinitionTest(unittest.TestCase): form_values = {"val_integer": 6} errors, values = fd.validate(form_values) self.assertNotIn('val_integer', errors) - self.assertEquals(values['val_integer'], 6) + self.assertEqual(values['val_integer'], 6) def testValidateFloatInvalid(self): fd = self.fc.get_form_def('test_val_float') @@ -171,7 +171,7 @@ class FormDefinitionTest(unittest.TestCase): form_values = {"val_float": 2.29} errors, values = fd.validate(form_values) self.assertNotIn('val_float', errors) - self.assertEquals(values['val_float'], 2.29) + self.assertEqual(values['val_float'], 2.29) def testValidateDateInvalid(self): fd = self.fc.get_form_def('test_val_date') @@ -200,14 +200,14 @@ class FormDefinitionTest(unittest.TestCase): form_values = {"val_date": '2015-03-03'} errors, values = fd.validate(form_values) self.assertNotIn('val_date', errors) - self.assertEquals(values['val_date'], datetime.date(2015, 3, 3)) + self.assertEqual(values['val_date'], datetime.date(2015, 3, 3)) def testValidateSelectValue(self): fd = self.fc.get_form_def('test_val_select') form_values = {"val_select": 'option_a'} errors, values = fd.validate(form_values) self.assertNotIn('val_select', errors) - self.assertEquals(values['val_select'], 'option_a') + self.assertEqual(values['val_select'], 'option_a') def testValidateSelectInvalid(self): fd = self.fc.get_form_def('test_val_select') @@ -221,14 +221,14 @@ class FormDefinitionTest(unittest.TestCase): form_values = {"val_checkbox": 'on'} errors, values = fd.validate(form_values) self.assertNotIn('val_checkbox', errors) - self.assertEquals(values['val_checkbox'], 'on') + self.assertEqual(values['val_checkbox'], 'on') def testValidateCheckboxDefaultOn(self): fd = self.fc.get_form_def('test_val_checkbox_on') form_values = {"val_checkbox_on": 'off'} errors, values = fd.validate(form_values) self.assertNotIn('val_checkbox_on', errors) - self.assertEquals(values['val_checkbox_on'], 'off') + self.assertEqual(values['val_checkbox_on'], 'off') def testValidateCheckboxInvalid(self): fd = self.fc.get_form_def('test_val_checkbox') @@ -280,22 +280,27 @@ class WebAppTest(unittest.TestCase): cls.auth_admin = requests.auth.HTTPBasicAuth('admin', 'admin') cls.auth_user = requests.auth.HTTPBasicAuth('user', 'user') + # Run the server in a thread, so we can execute the tests in the main + # program. def server_thread(sf): sf.run(listen_port=8002) cls.sf = scriptform.ScriptForm('test_webapp.json') - thread.start_new_thread(server_thread, (cls.sf, )) - # Wait until the webserver is ready + + thread = threading.Thread(target=server_thread, args=(cls.sf,)) + thread.start() + while True: time.sleep(0.1) - if cls.sf.running: + if cls.sf.running is True: break @classmethod def tearDownClass(cls): + # Shut down the webserver and wait until it has shut down. cls.sf.shutdown() while True: time.sleep(0.1) - if not cls.sf.running: + if cls.sf.running is False: break def testError404(self): @@ -356,25 +361,25 @@ class WebAppTest(unittest.TestCase): } import random - f = file('data.csv', 'w') - for i in range(1024): - f.write(chr(random.randint(0, 255))) - f.close() - - files = {'file': open('data.csv', 'rb')} - r = requests.post("http://localhost:8002/submit", data=data, files=files, auth=self.auth_user) - - self.assertIn('string=12345', r.text) - self.assertIn('integer=12', r.text) - self.assertIn('float=0.6', r.text) - self.assertIn('date=2015-01-02', r.text) - self.assertIn('text=1234567890', r.text) - self.assertIn('password=12345', r.text) - self.assertIn('radio=One', r.text) - self.assertIn('checkbox=on', r.text) - self.assertIn('select=option_a', r.text) - - os.unlink('data.csv') + with open('data.csv', 'w') as fh: + for i in range(1024): + fh.write(chr(random.randint(0, 255))) + + with open('data.csv', 'rb') as fh: + files = {'file': fh} + r = requests.post("http://localhost:8002/submit", data=data, files=files, auth=self.auth_user) + + self.assertIn('string=12345', r.text) + self.assertIn('integer=12', r.text) + self.assertIn('float=0.6', r.text) + self.assertIn('date=2015-01-02', r.text) + self.assertIn('text=1234567890', r.text) + self.assertIn('password=12345', r.text) + self.assertIn('radio=One', r.text) + self.assertIn('checkbox=on', r.text) + self.assertIn('select=option_a', r.text) + + os.unlink('data.csv') def testValidateIncorrectData(self): data = { @@ -391,26 +396,26 @@ class WebAppTest(unittest.TestCase): } import random - f = file('data.txt', 'w') - for i in range(1024): - f.write(chr(random.randint(0, 255))) - f.close() - - files = {'file': open('data.txt', 'rb')} - r = requests.post("http://localhost:8002/submit", data=data, files=files, auth=self.auth_user) - - self.assertIn('Maximum length is 7', r.text) - self.assertIn('Minimum value is 10', r.text) - self.assertIn('Maximum value is 1.0', r.text) - self.assertIn('Maximum value is 2015-02-01', r.text) - self.assertIn('Invalid value for radio button: Ten', r.text) - self.assertIn('Minimum length is 10', r.text) - self.assertIn('Minimum length is 5', r.text) - self.assertIn('Only file types allowed: csv', r.text) - self.assertIn('Invalid value for radio button', r.text) - self.assertIn('Invalid value for dropdown', r.text) - - os.unlink('data.txt') + with open('data.txt', 'w') as fh: + for i in range(1024): + fh.write(chr(random.randint(0, 255))) + + with open('data.txt', 'rb') as fh: + files = {'file': fh} + r = requests.post("http://localhost:8002/submit", data=data, files=files, auth=self.auth_user) + + self.assertIn('Maximum length is 7', r.text) + self.assertIn('Minimum value is 10', r.text) + self.assertIn('Maximum value is 1.0', r.text) + self.assertIn('Maximum value is 2015-02-01', r.text) + self.assertIn('Invalid value for radio button: Ten', r.text) + self.assertIn('Minimum length is 10', r.text) + self.assertIn('Minimum length is 5', r.text) + self.assertIn('Only file types allowed: csv', r.text) + self.assertIn('Invalid value for radio button', r.text) + self.assertIn('Invalid value for dropdown', r.text) + + os.unlink('data.txt') def testValidateRefill(self): """ @@ -431,23 +436,23 @@ class WebAppTest(unittest.TestCase): } import random - f = file('data.txt', 'w') - for i in range(1024): - f.write(chr(random.randint(0, 255))) - f.close() - - files = {'file': open('data.txt', 'rb')} - r = requests.post("http://localhost:8002/submit", data=data, files=files, auth=self.auth_user) - self.assertIn('value="123"', r.text) - self.assertIn('value="12"', r.text) - self.assertIn('value="0.6"', r.text) - self.assertIn('value="2015-01-02"', r.text) - self.assertIn('>1234567890<', r.text) - self.assertIn('value="12345"', r.text) - self.assertIn('value="on"', r.text) - self.assertIn('selected>Option B', r.text) - - os.unlink('data.txt') + with open('data.txt', 'w') as fh: + for i in range(1024): + fh.write(chr(random.randint(0, 255))) + + with open ('data.txt', 'rb') as fh: + files = {'file': fh} + r = requests.post("http://localhost:8002/submit", data=data, files=files, auth=self.auth_user) + self.assertIn('value="123"', r.text) + self.assertIn('value="12"', r.text) + self.assertIn('value="0.6"', r.text) + self.assertIn('value="2015-01-02"', r.text) + self.assertIn('>1234567890<', r.text) + self.assertIn('value="12345"', r.text) + self.assertIn('value="on"', r.text) + self.assertIn('selected>Option B', r.text) + + os.unlink('data.txt') def testOutputEscaped(self): """Form with 'escaped' output should have HTML entities escaped""" @@ -476,36 +481,36 @@ class WebAppTest(unittest.TestCase): def testUpload(self): import random - f = file('data.raw', 'w') - for i in range(1024): - f.write(chr(random.randint(0, 255))) - f.close() + with open('data.raw', 'w') as fh: + fh.write(chr(random.randint(0, 255))) data = { "form_name": "upload" } - files = {'file': open('data.raw', 'rb')} - r = requests.post("http://localhost:8002/submit", files=files, data=data, auth=self.auth_user) - self.assertIn('SAME', r.text) - os.unlink('data.raw') + with open('data.raw', 'rb') as fh: + files = {'file': fh} + r = requests.post("http://localhost:8002/submit", files=files, data=data, auth=self.auth_user) + self.assertIn('SAME', r.text) + os.unlink('data.raw') def testStaticValid(self): r = requests.get("http://localhost:8002/static?fname=ssh_server.png", auth=self.auth_user) - self.assertEquals(r.status_code, 200) + self.assertEqual(r.status_code, 200) f_served = b'' for c in r.iter_content(): f_served += c - f_orig = file('static/ssh_server.png', 'rb').read() - self.assertEquals(f_orig, f_served) + with open('static/ssh_server.png', 'rb')as fh: + f_orig = fh.read() + self.assertEqual(f_orig, f_served) def testStaticInvalidFilename(self): r = requests.get("http://localhost:8002/static?fname=../../ssh_server.png", auth=self.auth_user) - self.assertEquals(r.status_code, 403) + self.assertEqual(r.status_code, 403) def testStaticInvalidNotFound(self): r = requests.get("http://localhost:8002/static?fname=nosuchfile.png", auth=self.auth_user) - self.assertEquals(r.status_code, 404) + self.assertEqual(r.status_code, 404) def testHiddenField(self): r = requests.get('http://localhost:8002/form?form_name=hidden_field', auth=self.auth_user) @@ -526,14 +531,18 @@ class WebAppSingleTest(unittest.TestCase): """ @classmethod def setUpClass(cls): + # Run the server in a thread, so we can execute the tests in the main + # program. def server_thread(sf): sf.run(listen_port=8002) cls.sf = scriptform.ScriptForm('test_webapp_singleform.json') - thread.start_new_thread(server_thread, (cls.sf, )) - # Wait until the webserver is ready + + thread = threading.Thread(target=server_thread, args=(cls.sf,)) + thread.start() + while True: time.sleep(0.1) - if cls.sf.running: + if cls.sf.running is True: break @classmethod @@ -555,7 +564,7 @@ class WebAppSingleTest(unittest.TestCase): """ """ r = requests.get("http://localhost:8002/static?fname=nosuchfile.png") - self.assertEquals(r.status_code, 501) + self.assertEqual(r.status_code, 501) if __name__ == '__main__': @@ -575,11 +584,11 @@ if __name__ == '__main__': cov.stop() cov.save() - print cov.report() + print(cov.report()) try: - print cov.html_report() - except coverage.misc.CoverageException, e: - if "Couldn't find static file 'jquery.hotkeys.js'" in e.message: + print(cov.html_report()) + except coverage.misc.CoverageException as err: + if "Couldn't find static file 'jquery.hotkeys.js'" in err.message: pass else: raise