Merge branch 'py3'

master
Ferry Boender 4 years ago
commit 40badfe2d6
  1. 2
      build.sla
  2. 2
      examples/auth/auth.json
  3. 4
      examples/simple/job_import.sh
  4. 9
      src/daemon.py
  5. 1
      src/formconfig.py
  6. 5
      src/formdefinition.py
  7. 133
      src/scriptform.py
  8. 22
      src/webapp.py
  9. 32
      src/webserver.py
  10. 107
      test/test.py

@ -22,7 +22,7 @@ test () {
# Code quality linting (pylint) # Code quality linting (pylint)
cd $ROOTDIR cd $ROOTDIR
cd src && pylint --reports=n -dR -d star-args -d no-member *.py || true cd src && pylint --reports=n -dR -d subprocess-popen-preexec-fn -d invalid-name -d star-args -d no-member *.py || true
cd $ROOTDIR cd $ROOTDIR
} }

@ -8,7 +8,7 @@
{ {
"name": "do_nothing", "name": "do_nothing",
"title": "Test form", "title": "Test form",
"description": "You should only see this if you've entered the correct password", "description": "All users that logged in should be able to see this.",
"submit_title": "Do nothing", "submit_title": "Do nothing",
"script": "job_do_nothing.sh", "script": "job_do_nothing.sh",
"fields": [ "fields": [

@ -8,3 +8,7 @@ echo "This is what would be executed if this wasn't a fake script:"
echo echo
echo " echo 'DROP DATABASE $target_db' | $MYSQL" echo " echo 'DROP DATABASE $target_db' | $MYSQL"
echo " $MYSQL ${target_db} < ${sql_file}" echo " $MYSQL ${target_db} < ${sql_file}"
echo
echo "The uploaded file was $(stat --printf="%s" $sql_file) bytes"
echo "The (binary) md5 hash of the uploaded file is: $(md5sum -b $sql_file | cut -d " " -f1)"

@ -15,7 +15,6 @@ class DaemonError(Exception):
""" """
Default error for Daemon class. Default error for Daemon class.
""" """
pass
class Daemon(object): # pragma: no cover class Daemon(object): # pragma: no cover
@ -97,7 +96,8 @@ class Daemon(object): # pragma: no cover
return None return None
try: try:
pid = int(file(self.pid_file, 'r').read().strip()) with open(self.pid_file, "r") as fh:
pid = int(fh.read().strip())
except ValueError: except ValueError:
return None return None
@ -137,9 +137,8 @@ class Daemon(object): # pragma: no cover
pid = os.fork() pid = os.fork()
if pid > 0: if pid > 0:
self.log.info("PID = %s", pid) self.log.info("PID = %s", pid)
pidfile = file(self.pid_file, 'w') with open(self.pid_file, "w") as fh:
pidfile.write(str(pid)) fh.write(str(pid))
pidfile.close()
sys.exit(0) # End parent sys.exit(0) # End parent
atexit.register(self._cleanup) atexit.register(self._cleanup)

@ -13,7 +13,6 @@ class FormConfigError(Exception):
""" """
Default error for FormConfig errors Default error for FormConfig errors
""" """
pass
class FormConfig(object): class FormConfig(object):

@ -10,8 +10,9 @@ import runscript
class ValidationError(Exception): class ValidationError(Exception):
"""Default exception for Validation errors""" """
pass Default exception for Validation errors
"""
class FormDefinition(object): class FormDefinition(object):

@ -1,4 +1,4 @@
#!/usr/bin/env python #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
@ -6,13 +6,12 @@ Main ScriptForm program
""" """
import sys import sys
import optparse import argparse
import os import os
import json import json
import logging import logging
import thread import threading
import hashlib import hashlib
import socket
if hasattr(sys, 'dont_write_bytecode'): if hasattr(sys, 'dont_write_bytecode'):
sys.dont_write_bytecode = True sys.dont_write_bytecode = True
@ -51,7 +50,8 @@ class ScriptForm(object):
if self.cache and self.form_config_singleton is not None: if self.cache and self.form_config_singleton is not None:
return self.form_config_singleton return self.form_config_singleton
file_contents = file(self.config_file, 'r').read() with open(self.config_file, "r") as fh:
file_contents = fh.read()
try: try:
config = json.loads(file_contents) config = json.loads(file_contents)
except ValueError as err: except ValueError as err:
@ -67,7 +67,8 @@ class ScriptForm(object):
if 'static_dir' in config: if 'static_dir' in config:
static_dir = config['static_dir'] static_dir = config['static_dir']
if 'custom_css' in config: if 'custom_css' in config:
custom_css = file(config['custom_css'], 'r').read() with open(config["custom_css"], "r") as fh:
custom_css = fh.read()
if 'users' in config: if 'users' in config:
users = config['users'] users = config['users']
for form in config['forms']: for form in config['forms']:
@ -116,7 +117,10 @@ class ScriptForm(object):
self.httpd.daemon_threads = True self.httpd.daemon_threads = True
self.log.info("Listening on %s:%s", listen_addr, listen_port) self.log.info("Listening on %s:%s", listen_addr, listen_port)
self.running = True self.running = True
try:
self.httpd.serve_forever() self.httpd.serve_forever()
except KeyboardInterrupt:
pass
self.running = False self.running = False
def shutdown(self): def shutdown(self):
@ -138,41 +142,60 @@ class ScriptForm(object):
# We need to spawn a new thread in which the server is shut down, # We need to spawn a new thread in which the server is shut down,
# because doing it from the main thread blocks, since the server is # because doing it from the main thread blocks, since the server is
# waiting for connections.. # waiting for connections..
thread.start_new_thread(t_shutdown, (self, )) thread = threading.Thread(target=t_shutdown, args=(self,))
thread.start()
def main(): # pragma: no cover def main(): # pragma: no cover
""" """
main method main method
""" """
usage = [ parser = argparse.ArgumentParser(description='My Application.')
sys.argv[0] + " [option] (--start|--stop) <form_definition.json>", parser.add_argument('--version',
" " + sys.argv[0] + " --generate-pw", action='version',
] version='%(prog)s %%VERSION%%')
parser = optparse.OptionParser(version="%%VERSION%%") parser.add_argument('-g', '--generate-pw',
parser.set_usage('\n'.join(usage)) action='store_true',
default=False,
parser.add_option("-g", "--generate-pw", dest="generate_pw", help='Generate password')
action="store_true", default=False, parser.add_argument('-p', '--port',
help="Generate password") metavar='PORT',
parser.add_option("-p", "--port", dest="port", action="store", type="int", dest='port',
default=8081, help="Port to listen on (default=8081)") type=int,
parser.add_option("-f", "--foreground", dest="foreground", default=8081,
action="store_true", default=False, help='Port to listen on (default=8081)')
help="Run in foreground (debugging)") parser.add_argument('-f', '--foreground',
parser.add_option("-r", "--reload", dest="reload", action="store_true", dest='foreground',
action='store_true',
default=False,
help='Run in foreground (debugging)')
parser.add_argument('-r', '--reload',
dest='reload',
action='store_true',
default=False, default=False,
help="Reload form config on every request (DEV)") help='Reload form config on every request (DEV)')
parser.add_option("--pid-file", dest="pid_file", action="store", parser.add_argument('--pid-file',
default=None, help="Pid file") metavar='PATH',
parser.add_option("--log-file", dest="log_file", action="store", dest='pid_file',
default=None, help="Log file") type=str,
parser.add_option("--start", dest="action_start", action="store_true", default=None,
default=None, help="Start daemon") help='Pid file')
parser.add_option("--stop", dest="action_stop", action="store_true", parser.add_argument('--log-file',
default=None, help="Stop daemon") metavar='PATH',
dest='log_file',
(options, args) = parser.parse_args() type=str,
default=None,
help='Log file')
parser.add_argument('--stop',
dest='action_stop',
action='store_true',
default=None,
help='Stop daemon')
parser.add_argument(dest='config',
metavar="CONFIG_FILE",
help="Path to form definition config",
)
options = parser.parse_args()
if options.generate_pw: if options.generate_pw:
# Generate a password for use in the `users` section # Generate a password for use in the `users` section
@ -181,45 +204,27 @@ def main(): # pragma: no cover
if plain_pw != getpass.getpass('Repeat password: '): if plain_pw != getpass.getpass('Repeat password: '):
sys.stderr.write("Passwords do not match.\n") sys.stderr.write("Passwords do not match.\n")
sys.exit(1) sys.exit(1)
sys.stdout.write(hashlib.sha256(plain_pw).hexdigest() + '\n') sha = hashlib.sha256(plain_pw.encode('utf8')).hexdigest()
sys.stdout.write("{}\n".format(sha))
sys.exit(0) sys.exit(0)
else: else:
if not options.action_stop and len(args) < 1: # Switch to dir of form definition configuration
parser.error("Insufficient number of arguments") formconfig_path = os.path.realpath(options.config)
if not options.action_stop and not options.action_start: os.chdir(os.path.dirname(formconfig_path))
options.action_start = True
# If a form configuration was specified, change to that dir so we can
# find the job scripts and such.
if args:
path = os.path.dirname(args[0])
if path:
os.chdir(path)
args[0] = os.path.basename(args[0])
# Initialize daemon so we can start or stop it
daemon = Daemon(options.pid_file, options.log_file, daemon = Daemon(options.pid_file, options.log_file,
foreground=options.foreground) foreground=options.foreground)
log = logging.getLogger('MAIN')
try: if options.action_stop:
if options.action_start: daemon.stop()
sys.exit(0)
else:
cache = not options.reload cache = not options.reload
scriptform_instance = ScriptForm(args[0], cache=cache) scriptform_instance = ScriptForm(formconfig_path, cache=cache)
daemon.register_shutdown_callback(scriptform_instance.shutdown) daemon.register_shutdown_callback(scriptform_instance.shutdown)
daemon.start() daemon.start()
scriptform_instance.run(listen_port=options.port) scriptform_instance.run(listen_port=options.port)
elif options.action_stop:
daemon.stop()
sys.exit(0)
except socket.error as err:
log.exception(err)
sys.stderr.write("Cannot bind to port {0}: {1}\n".format(
options.port,
str(err)
))
sys.exit(2)
except Exception as err:
log.exception(err)
raise
if __name__ == "__main__": # pragma: no cover if __name__ == "__main__": # pragma: no cover

@ -3,7 +3,7 @@ The webapp part of Scriptform, which takes care of serving requests and
handling them. handling them.
""" """
import cgi import html
import logging import logging
import tempfile import tempfile
import os import os
@ -209,14 +209,13 @@ class ScriptFormWebApp(RequestHandler):
# If a 'users' element was present in the form configuration file, the # If a 'users' element was present in the form configuration file, the
# user must be authenticated. # user must be authenticated.
if form_config.users: if form_config.users:
auth_header = self.headers.getheader("Authorization") auth_header = self.headers.get("Authorization")
if auth_header is not None: if auth_header is not None:
# Validate the username and password # Validate the username and password
auth_unpw = auth_header.split(' ', 1)[1] auth_unpw = auth_header.split(' ', 1)[1].encode('utf-8')
username, password = base64.decodestring(auth_unpw).split(":", username, password = \
1) base64.b64decode(auth_unpw).decode('utf-8').split(":", 1)
pw_hash = hashlib.sha256(password).hexdigest() pw_hash = hashlib.sha256(password.encode('utf-8')).hexdigest()
if username in form_config.users and \ if username in form_config.users and \
pw_hash == form_config.users[username]: pw_hash == form_config.users[username]:
# Valid username and password. Return the username. # Valid username and password. Return the username.
@ -414,13 +413,12 @@ class ScriptFormWebApp(RequestHandler):
if field.filename == '': if field.filename == '':
continue continue
tmp_fname = tempfile.mktemp(prefix="scriptform_") tmp_fname = tempfile.mktemp(prefix="scriptform_")
tmp_file = file(tmp_fname, 'w') with open(tmp_fname, "wb") as tmp_file:
while True: while True:
buf = field.file.read(1024 * 16) buf = field.file.read(1024 * 16)
if not buf: if not buf:
break break
tmp_file.write(buf) tmp_file.write(buf)
tmp_file.close()
field.file.close() field.file.close()
tmp_files.append(tmp_fname) # For later cleanup tmp_files.append(tmp_fname) # For later cleanup
@ -461,11 +459,11 @@ class ScriptFormWebApp(RequestHandler):
# Ignore everything if we're doing raw output, since it's the # Ignore everything if we're doing raw output, since it's the
# scripts responsibility. # scripts responsibility.
if result['exitcode'] != 0: if result['exitcode'] != 0:
stderr = cgi.escape(result['stderr'].decode('utf8')) stderr = html.escape(result['stderr'].decode('utf8'))
msg = u'<span class="error">{0}</span>'.format(stderr) msg = u'<span class="error">{0}</span>'.format(stderr)
else: else:
if form_def.output == 'escaped': if form_def.output == 'escaped':
stdout = cgi.escape(result['stdout'].decode('utf8')) stdout = html.escape(result['stdout'].decode('utf8'))
msg = u'<pre>{0}</pre>'.format(stdout) msg = u'<pre>{0}</pre>'.format(stdout)
else: else:
# Non-escaped output (html, usually) # Non-escaped output (html, usually)
@ -509,7 +507,7 @@ class ScriptFormWebApp(RequestHandler):
if not os.path.exists(path): if not os.path.exists(path):
raise HTTPError(404, "Not found") raise HTTPError(404, "Not found")
static_file = file(path, 'r')
self.send_response(200) self.send_response(200)
self.end_headers() self.end_headers()
with open(path, "rb") as static_file:
self.wfile.write(static_file.read()) self.wfile.write(static_file.read())

@ -2,10 +2,10 @@
Basic web server / framework. Basic web server / framework.
""" """
import BaseHTTPServer from socketserver import ThreadingMixIn
from http.server import HTTPServer, BaseHTTPRequestHandler
import urllib.parse
import cgi import cgi
import urlparse
from SocketServer import ThreadingMixIn
class HTTPError(Exception): class HTTPError(Exception):
@ -14,6 +14,9 @@ class HTTPError(Exception):
etc. They are caught by the 'framework' and sent to the client's browser. etc. They are caught by the 'framework' and sent to the client's browser.
""" """
def __init__(self, status_code, msg, headers=None): def __init__(self, status_code, msg, headers=None):
assert isinstance(status_code, int)
assert isinstance(msg, str)
if headers is None: if headers is None:
headers = {} headers = {}
self.status_code = status_code self.status_code = status_code
@ -22,14 +25,13 @@ class HTTPError(Exception):
Exception.__init__(self, status_code, msg, headers) Exception.__init__(self, status_code, msg, headers)
class ThreadedHTTPServer(ThreadingMixIn, BaseHTTPServer.HTTPServer): class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
""" """
Base class for multithreaded HTTP servers. Base class for multithreaded HTTP servers.
""" """
pass
class RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler): class RequestHandler(BaseHTTPRequestHandler):
""" """
Basic web server request handler. Handles GET and POST requests. You should Basic web server request handler. Handles GET and POST requests. You should
inherit from this class and implement h_ methods for handling requests. inherit from this class and implement h_ methods for handling requests.
@ -57,25 +59,13 @@ class RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
environ={'REQUEST_METHOD': 'POST'}) environ={'REQUEST_METHOD': 'POST'})
self._call(self.path.strip('/'), params={'form_values': form_values}) self._call(self.path.strip('/'), params={'form_values': form_values})
def do_OPTIONS(self): # pylint: disable=invalid-name
"""
Handle OPTIONS request and return CORS headers.
"""
self.send_response(200, 'ok')
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS')
self.send_header('Access-Control-Allow-Headers', 'X-Requested-With')
self.send_header('Access-Control-Allow-Headers', 'Content-Type, '
'Authorization')
self.end_headers()
def _parse(self, reqinfo): def _parse(self, reqinfo):
""" """
Parse information from a request. Parse information from a request.
""" """
url_comp = urlparse.urlsplit(reqinfo) url_comp = urllib.parse.urlsplit(reqinfo)
path = url_comp.path path = url_comp.path
query_vars = urlparse.parse_qs(url_comp.query) query_vars = urllib.parse.parse_qs(url_comp.query)
# Only return the first value of each query var. E.g. for # Only return the first value of each query var. E.g. for
# "?foo=1&foo=2" return '1'. # "?foo=1&foo=2" return '1'.
var_values = dict([(k, v[0]) for k, v in query_vars.items()]) var_values = dict([(k, v[0]) for k, v in query_vars.items()])
@ -116,7 +106,7 @@ class RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
self.send_header(header_k, header_v) self.send_header(header_k, header_v)
self.end_headers() self.end_headers()
self.wfile.write("Error {0}: {1}".format(err.status_code, self.wfile.write("Error {0}: {1}".format(err.status_code,
err.msg)) err.msg).encode('utf-8'))
self.wfile.flush() self.wfile.flush()
return False return False
except Exception as err: except Exception as err:

@ -1,14 +1,12 @@
import logging import logging
import sys import sys
import unittest import unittest
from StringIO import StringIO
import json import json
import os import os
import copy import copy
import thread import threading
import time import time
import requests import requests
import StringIO
import re import re
@ -51,22 +49,24 @@ class FormConfigTestCase(unittest.TestCase):
fc = sf.get_form_config() fc = sf.get_form_config()
fd = fc.get_form_def('test_store') fd = fc.get_form_def('test_store')
res = runscript.run_script(fd, {}, {}) res = runscript.run_script(fd, {}, {})
self.assertEquals(res['exitcode'], 33) self.assertEqual(res['exitcode'], 33)
self.assertTrue('stdout' in res['stdout']) self.assertTrue(b'stdout' in res['stdout'])
self.assertTrue('stderr' in res['stderr']) self.assertTrue(b'stderr' in res['stderr'])
def testCallbackRaw(self): def testCallbackRaw(self):
"""Test a callback that returns raw output""" """Test a callback that returns raw output"""
sf = scriptform.ScriptForm('test_formconfig_callback.json') sf = scriptform.ScriptForm('test_formconfig_callback.json')
fc = sf.get_form_config() fc = sf.get_form_config()
fd = fc.get_form_def('test_raw') fd = fc.get_form_def('test_raw')
stdout = file('tmp_stdout', 'w+') # can't use StringIO stdout = open('tmp_stdout', 'w+') # can't use StringIO
stderr = file('tmp_stderr', 'w+') stderr = open('tmp_stderr', 'w+')
exitcode = runscript.run_script(fd, {}, {}, stdout, stderr) exitcode = runscript.run_script(fd, {}, {}, stdout, stderr)
stdout.seek(0) stdout.seek(0)
stderr.seek(0) stderr.seek(0)
self.assertTrue(exitcode == 33) self.assertTrue(exitcode == 33)
self.assertTrue('stdout' in stdout.read()) self.assertTrue('stdout' in stdout.read())
stdout.close()
stderr.close()
def testCallbackMissingParams(self): def testCallbackMissingParams(self):
""" """
@ -115,7 +115,7 @@ class FormDefinitionTest(unittest.TestCase):
form_values = {"val_string": "1234"} form_values = {"val_string": "1234"}
errors, values = fd.validate(form_values) errors, values = fd.validate(form_values)
self.assertNotIn('val_string', errors) self.assertNotIn('val_string', errors)
self.assertEquals(values['val_string'], "1234") self.assertEqual(values['val_string'], "1234")
def testValidateIntegerInvalid(self): def testValidateIntegerInvalid(self):
fd = self.fc.get_form_def('test_val_integer') fd = self.fc.get_form_def('test_val_integer')
@ -143,7 +143,7 @@ class FormDefinitionTest(unittest.TestCase):
form_values = {"val_integer": 6} form_values = {"val_integer": 6}
errors, values = fd.validate(form_values) errors, values = fd.validate(form_values)
self.assertNotIn('val_integer', errors) self.assertNotIn('val_integer', errors)
self.assertEquals(values['val_integer'], 6) self.assertEqual(values['val_integer'], 6)
def testValidateFloatInvalid(self): def testValidateFloatInvalid(self):
fd = self.fc.get_form_def('test_val_float') fd = self.fc.get_form_def('test_val_float')
@ -171,7 +171,7 @@ class FormDefinitionTest(unittest.TestCase):
form_values = {"val_float": 2.29} form_values = {"val_float": 2.29}
errors, values = fd.validate(form_values) errors, values = fd.validate(form_values)
self.assertNotIn('val_float', errors) self.assertNotIn('val_float', errors)
self.assertEquals(values['val_float'], 2.29) self.assertEqual(values['val_float'], 2.29)
def testValidateDateInvalid(self): def testValidateDateInvalid(self):
fd = self.fc.get_form_def('test_val_date') fd = self.fc.get_form_def('test_val_date')
@ -200,14 +200,14 @@ class FormDefinitionTest(unittest.TestCase):
form_values = {"val_date": '2015-03-03'} form_values = {"val_date": '2015-03-03'}
errors, values = fd.validate(form_values) errors, values = fd.validate(form_values)
self.assertNotIn('val_date', errors) self.assertNotIn('val_date', errors)
self.assertEquals(values['val_date'], datetime.date(2015, 3, 3)) self.assertEqual(values['val_date'], datetime.date(2015, 3, 3))
def testValidateSelectValue(self): def testValidateSelectValue(self):
fd = self.fc.get_form_def('test_val_select') fd = self.fc.get_form_def('test_val_select')
form_values = {"val_select": 'option_a'} form_values = {"val_select": 'option_a'}
errors, values = fd.validate(form_values) errors, values = fd.validate(form_values)
self.assertNotIn('val_select', errors) self.assertNotIn('val_select', errors)
self.assertEquals(values['val_select'], 'option_a') self.assertEqual(values['val_select'], 'option_a')
def testValidateSelectInvalid(self): def testValidateSelectInvalid(self):
fd = self.fc.get_form_def('test_val_select') fd = self.fc.get_form_def('test_val_select')
@ -221,14 +221,14 @@ class FormDefinitionTest(unittest.TestCase):
form_values = {"val_checkbox": 'on'} form_values = {"val_checkbox": 'on'}
errors, values = fd.validate(form_values) errors, values = fd.validate(form_values)
self.assertNotIn('val_checkbox', errors) self.assertNotIn('val_checkbox', errors)
self.assertEquals(values['val_checkbox'], 'on') self.assertEqual(values['val_checkbox'], 'on')
def testValidateCheckboxDefaultOn(self): def testValidateCheckboxDefaultOn(self):
fd = self.fc.get_form_def('test_val_checkbox_on') fd = self.fc.get_form_def('test_val_checkbox_on')
form_values = {"val_checkbox_on": 'off'} form_values = {"val_checkbox_on": 'off'}
errors, values = fd.validate(form_values) errors, values = fd.validate(form_values)
self.assertNotIn('val_checkbox_on', errors) self.assertNotIn('val_checkbox_on', errors)
self.assertEquals(values['val_checkbox_on'], 'off') self.assertEqual(values['val_checkbox_on'], 'off')
def testValidateCheckboxInvalid(self): def testValidateCheckboxInvalid(self):
fd = self.fc.get_form_def('test_val_checkbox') fd = self.fc.get_form_def('test_val_checkbox')
@ -280,22 +280,27 @@ class WebAppTest(unittest.TestCase):
cls.auth_admin = requests.auth.HTTPBasicAuth('admin', 'admin') cls.auth_admin = requests.auth.HTTPBasicAuth('admin', 'admin')
cls.auth_user = requests.auth.HTTPBasicAuth('user', 'user') cls.auth_user = requests.auth.HTTPBasicAuth('user', 'user')
# Run the server in a thread, so we can execute the tests in the main
# program.
def server_thread(sf): def server_thread(sf):
sf.run(listen_port=8002) sf.run(listen_port=8002)
cls.sf = scriptform.ScriptForm('test_webapp.json') cls.sf = scriptform.ScriptForm('test_webapp.json')
thread.start_new_thread(server_thread, (cls.sf, ))
# Wait until the webserver is ready thread = threading.Thread(target=server_thread, args=(cls.sf,))
thread.start()
while True: while True:
time.sleep(0.1) time.sleep(0.1)
if cls.sf.running: if cls.sf.running is True:
break break
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
# Shut down the webserver and wait until it has shut down.
cls.sf.shutdown() cls.sf.shutdown()
while True: while True:
time.sleep(0.1) time.sleep(0.1)
if not cls.sf.running: if cls.sf.running is False:
break break
def testError404(self): def testError404(self):
@ -356,12 +361,12 @@ class WebAppTest(unittest.TestCase):
} }
import random import random
f = file('data.csv', 'w') with open('data.csv', 'w') as fh:
for i in range(1024): for i in range(1024):
f.write(chr(random.randint(0, 255))) fh.write(chr(random.randint(0, 255)))
f.close()
files = {'file': open('data.csv', 'rb')} with open('data.csv', 'rb') as fh:
files = {'file': fh}
r = requests.post("http://localhost:8002/submit", data=data, files=files, auth=self.auth_user) r = requests.post("http://localhost:8002/submit", data=data, files=files, auth=self.auth_user)
self.assertIn('string=12345', r.text) self.assertIn('string=12345', r.text)
@ -391,12 +396,12 @@ class WebAppTest(unittest.TestCase):
} }
import random import random
f = file('data.txt', 'w') with open('data.txt', 'w') as fh:
for i in range(1024): for i in range(1024):
f.write(chr(random.randint(0, 255))) fh.write(chr(random.randint(0, 255)))
f.close()
files = {'file': open('data.txt', 'rb')} with open('data.txt', 'rb') as fh:
files = {'file': fh}
r = requests.post("http://localhost:8002/submit", data=data, files=files, auth=self.auth_user) r = requests.post("http://localhost:8002/submit", data=data, files=files, auth=self.auth_user)
self.assertIn('Maximum length is 7', r.text) self.assertIn('Maximum length is 7', r.text)
@ -431,12 +436,12 @@ class WebAppTest(unittest.TestCase):
} }
import random import random
f = file('data.txt', 'w') with open('data.txt', 'w') as fh:
for i in range(1024): for i in range(1024):
f.write(chr(random.randint(0, 255))) fh.write(chr(random.randint(0, 255)))
f.close()
files = {'file': open('data.txt', 'rb')} with open ('data.txt', 'rb') as fh:
files = {'file': fh}
r = requests.post("http://localhost:8002/submit", data=data, files=files, auth=self.auth_user) r = requests.post("http://localhost:8002/submit", data=data, files=files, auth=self.auth_user)
self.assertIn('value="123"', r.text) self.assertIn('value="123"', r.text)
self.assertIn('value="12"', r.text) self.assertIn('value="12"', r.text)
@ -476,36 +481,36 @@ class WebAppTest(unittest.TestCase):
def testUpload(self): def testUpload(self):
import random import random
f = file('data.raw', 'w') with open('data.raw', 'w') as fh:
for i in range(1024): fh.write(chr(random.randint(0, 255)))
f.write(chr(random.randint(0, 255)))
f.close()
data = { data = {
"form_name": "upload" "form_name": "upload"
} }
files = {'file': open('data.raw', 'rb')} with open('data.raw', 'rb') as fh:
files = {'file': fh}
r = requests.post("http://localhost:8002/submit", files=files, data=data, auth=self.auth_user) r = requests.post("http://localhost:8002/submit", files=files, data=data, auth=self.auth_user)
self.assertIn('SAME', r.text) self.assertIn('SAME', r.text)
os.unlink('data.raw') os.unlink('data.raw')
def testStaticValid(self): def testStaticValid(self):
r = requests.get("http://localhost:8002/static?fname=ssh_server.png", auth=self.auth_user) r = requests.get("http://localhost:8002/static?fname=ssh_server.png", auth=self.auth_user)
self.assertEquals(r.status_code, 200) self.assertEqual(r.status_code, 200)
f_served = b'' f_served = b''
for c in r.iter_content(): for c in r.iter_content():
f_served += c f_served += c
f_orig = file('static/ssh_server.png', 'rb').read() with open('static/ssh_server.png', 'rb')as fh:
self.assertEquals(f_orig, f_served) f_orig = fh.read()
self.assertEqual(f_orig, f_served)
def testStaticInvalidFilename(self): def testStaticInvalidFilename(self):
r = requests.get("http://localhost:8002/static?fname=../../ssh_server.png", auth=self.auth_user) r = requests.get("http://localhost:8002/static?fname=../../ssh_server.png", auth=self.auth_user)
self.assertEquals(r.status_code, 403) self.assertEqual(r.status_code, 403)
def testStaticInvalidNotFound(self): def testStaticInvalidNotFound(self):
r = requests.get("http://localhost:8002/static?fname=nosuchfile.png", auth=self.auth_user) r = requests.get("http://localhost:8002/static?fname=nosuchfile.png", auth=self.auth_user)
self.assertEquals(r.status_code, 404) self.assertEqual(r.status_code, 404)
def testHiddenField(self): def testHiddenField(self):
r = requests.get('http://localhost:8002/form?form_name=hidden_field', auth=self.auth_user) r = requests.get('http://localhost:8002/form?form_name=hidden_field', auth=self.auth_user)
@ -526,14 +531,18 @@ class WebAppSingleTest(unittest.TestCase):
""" """
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
# Run the server in a thread, so we can execute the tests in the main
# program.
def server_thread(sf): def server_thread(sf):
sf.run(listen_port=8002) sf.run(listen_port=8002)
cls.sf = scriptform.ScriptForm('test_webapp_singleform.json') cls.sf = scriptform.ScriptForm('test_webapp_singleform.json')
thread.start_new_thread(server_thread, (cls.sf, ))
# Wait until the webserver is ready thread = threading.Thread(target=server_thread, args=(cls.sf,))
thread.start()
while True: while True:
time.sleep(0.1) time.sleep(0.1)
if cls.sf.running: if cls.sf.running is True:
break break
@classmethod @classmethod
@ -555,7 +564,7 @@ class WebAppSingleTest(unittest.TestCase):
""" """
""" """
r = requests.get("http://localhost:8002/static?fname=nosuchfile.png") r = requests.get("http://localhost:8002/static?fname=nosuchfile.png")
self.assertEquals(r.status_code, 501) self.assertEqual(r.status_code, 501)
if __name__ == '__main__': if __name__ == '__main__':
@ -575,11 +584,11 @@ if __name__ == '__main__':
cov.stop() cov.stop()
cov.save() cov.save()
print cov.report() print(cov.report())
try: try:
print cov.html_report() print(cov.html_report())
except coverage.misc.CoverageException, e: except coverage.misc.CoverageException as err:
if "Couldn't find static file 'jquery.hotkeys.js'" in e.message: if "Couldn't find static file 'jquery.hotkeys.js'" in err.message:
pass pass
else: else:
raise raise

Loading…
Cancel
Save