# Source: .localenv/lib/python3.5/site-packages/jupyter_client/tests/test_connect.py (branch: master)

"""Tests for kernel connection utilities"""

# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

import json
import os

from traitlets.config import Config
from jupyter_core.application import JupyterApp
from jupyter_core.paths import jupyter_runtime_dir
from ipython_genutils.tempdir import TemporaryDirectory, TemporaryWorkingDirectory
from ipython_genutils.py3compat import str_to_bytes
from jupyter_client import connect, KernelClient
from jupyter_client.consoleapp import JupyterConsoleApp
from jupyter_client.session import Session


class DummyConsoleApp(JupyterApp, JupyterConsoleApp):
    """Minimal console app used by the tests to exercise connection files."""

    def initialize(self, argv=None):
        """Initialize the application, then set up its connection file.

        ``argv`` is the argument list passed to ``JupyterApp.initialize``.
        ``None`` (the default) is replaced by a fresh empty list, so calls
        without arguments behave as with the former ``argv=[]`` default but
        no longer share one mutable list object between calls.
        """
        if argv is None:
            argv = []
        JupyterApp.initialize(self, argv=argv)
        self.init_connection_file()

class DummyConfigurable(connect.ConnectionFileMixin):
    """ConnectionFileMixin subclass whose initialize() does nothing."""

    def initialize(self):
        """Intentionally a no-op."""

# Connection info shared by the tests below.
sample_info = {
    'ip': '1.2.3.4',
    'transport': 'ipc',
    'shell_port': 1,
    'hb_port': 2,
    'iopub_port': 3,
    'stdin_port': 4,
    'control_port': 5,
    'key': b'abc123',
    'signature_scheme': 'hmac-md5',
    'kernel_name': 'python',
}

# Identical to sample_info except that kernel_name is 'test'.
sample_info_kn = dict(sample_info, kernel_name='test')


def test_write_connection_file():
    """write_connection_file() writes JSON that matches sample_info."""
    with TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, 'kernel.json')
        connect.write_connection_file(path, **sample_info)
        assert os.path.exists(path)
        with open(path, 'r') as fh:
            loaded = json.loads(fh.read())
    # The key comes back from JSON as text; convert it to bytes before
    # comparing against the bytes key in sample_info.
    loaded['key'] = str_to_bytes(loaded['key'])
    assert loaded == sample_info


def test_load_connection_file_session():
    """load_connection_file() updates the key and signature scheme of the app's existing session"""
    app = DummyConsoleApp(session=Session())
    app.initialize(argv=[])
    # Take the session before loading the file, so the assertions below
    # check the object that existed prior to load_connection_file().
    session = app.session

    with TemporaryDirectory() as d:
        cf = os.path.join(d, 'kernel.json')
        connect.write_connection_file(cf, **sample_info)
        app.connection_file = cf
        app.load_connection_file()

    assert session.key == sample_info['key']
    assert session.signature_scheme == sample_info['signature_scheme']


def test_load_connection_file_session_with_kn():
    """load_connection_file() updates the app's existing session when given the sample_info_kn connection info"""
    app = DummyConsoleApp(session=Session())
    app.initialize(argv=[])
    # Take the session before loading the file, so the assertions below
    # check the object that existed prior to load_connection_file().
    session = app.session

    with TemporaryDirectory() as d:
        cf = os.path.join(d, 'kernel.json')
        connect.write_connection_file(cf, **sample_info_kn)
        app.connection_file = cf
        app.load_connection_file()

    assert session.key == sample_info_kn['key']
    assert session.signature_scheme == sample_info_kn['signature_scheme']


def test_app_load_connection_file():
    """An app given connection_file=... exposes that file's values after initialize()"""
    with TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, 'kernel.json')
        connect.write_connection_file(path, **sample_info)
        app = DummyConsoleApp(connection_file=path)
        app.initialize(argv=[])

    # These two entries are not compared against app attributes here.
    skipped = ('key', 'signature_scheme')
    for attr in sample_info:
        if attr not in skipped:
            expected = sample_info[attr]
            value = getattr(app, attr)
            assert value == expected, "app.%s = %s != %s" % (attr, value, expected)


def test_load_connection_info():
    """load_connection_info() applies values from a plain dict to the client."""
    connection_info = dict(
        control_port=53702,
        hb_port=53705,
        iopub_port=53703,
        ip='0.0.0.0',
        key='secret',
        shell_port=53700,
        signature_scheme='hmac-sha256',
        stdin_port=53701,
        transport='tcp',
    )
    kernel_client = KernelClient()
    kernel_client.load_connection_info(connection_info)

    assert kernel_client.control_port == connection_info['control_port']
    # The session key is bytes; decode it to compare with the text key.
    session_key = kernel_client.session.key
    assert session_key.decode('ascii') == connection_info['key']
    assert kernel_client.ip == connection_info['ip']


def test_find_connection_file():
    """find_connection_file() resolves an exact name and glob patterns within a given directory."""
    with TemporaryDirectory() as tmpdir:
        name = 'kernel.json'
        app = DummyConsoleApp(runtime_dir=tmpdir, connection_file=name)
        app.initialize()

        runtime_dir = app.runtime_dir
        expected = os.path.join(runtime_dir, name)

        with open(expected, 'w') as fh:
            fh.write("{}")

        queries = ['kernel.json', 'kern*', '*ernel*', 'k*']
        for query in queries:
            found = connect.find_connection_file(query, path=runtime_dir)
            assert found == expected


def test_find_connection_file_local():
    """find_connection_file() locates a file in the current working directory."""
    with TemporaryWorkingDirectory():
        name = 'test.json'
        expected = os.path.abspath(name)
        with open(name, 'w') as fh:
            fh.write('{}')

        # Full name, name without extension, absolute path, and ./-relative path.
        queries = [
            'test.json',
            'test',
            expected,
            os.path.join('.', 'test.json'),
        ]
        for query in queries:
            found = connect.find_connection_file(query, path=['.', jupyter_runtime_dir()])
            assert found == expected


def test_find_connection_file_relative():
    """find_connection_file() accepts paths that point into a subdirectory of the working directory."""
    with TemporaryWorkingDirectory():
        os.mkdir('subdir')
        cf = os.path.join('subdir', 'test.json')
        abs_cf = os.path.abspath(cf)
        with open(cf, 'w') as f:
            f.write('{}')

        # ./-prefixed relative path, plain relative path, and absolute path.
        for query in (
            os.path.join('.', 'subdir', 'test.json'),
            os.path.join('subdir', 'test.json'),
            abs_cf,
        ):
            assert connect.find_connection_file(query, path=['.', jupyter_runtime_dir()]) == abs_cf


def test_find_connection_file_abspath():
    """find_connection_file() returns an existing absolute path as given."""
    with TemporaryDirectory() as d:
        # Create the file inside the temporary directory rather than the
        # current working directory: the CWD is never written to, and the
        # file is removed with the directory even if the assertion fails.
        abs_cf = os.path.abspath(os.path.join(d, 'absolute.json'))
        with open(abs_cf, 'w') as f:
            f.write('{}')
        assert connect.find_connection_file(abs_cf, path=jupyter_runtime_dir()) == abs_cf

def test_mixin_record_random_ports():
    """After write_connection_file(), the mixin marks the file as written and lists every port name in _random_port_names."""
    with TemporaryDirectory() as tmpdir:
        configurable = DummyConfigurable(data_dir=tmpdir, kernel_name='via-tcp', transport='tcp')
        configurable.write_connection_file()

        assert configurable._connection_file_written
        written_path = configurable.connection_file
        assert os.path.exists(written_path)
        assert configurable._random_port_names == connect.port_names


def test_mixin_cleanup_random_ports():
    """cleanup_random_ports() removes the connection file and resets the recorded ports to 0."""
    with TemporaryDirectory() as tmpdir:
        configurable = DummyConfigurable(data_dir=tmpdir, kernel_name='via-tcp', transport='tcp')
        configurable.write_connection_file()
        # Remember the path before cleanup so its removal can be checked.
        written_path = configurable.connection_file
        configurable.cleanup_random_ports()

        assert not os.path.exists(written_path)
        for port_name in configurable._random_port_names:
            assert getattr(configurable, port_name) == 0