Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add client_session_keep_alive parameter #116

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from .sqlstate import (SQLSTATE_CONNECTION_NOT_EXISTS,
SQLSTATE_FEATURE_NOT_SUPPORTED)
from .telemetry import (TelemetryClient)
from .time_util import get_time_millis
from .time_util import HourlyTimer, get_time_millis
from .util_text import split_statements, construct_hostname

SUPPORTED_PARAMSTYLES = {
Expand Down Expand Up @@ -93,6 +93,7 @@
u'inject_client_pause': 0, # snowflake internal
u'session_parameters': {}, # snowflake session parameters
u'autocommit': None, # snowflake
u'client_session_keep_alive': False, # snowflake
u'numpy': False, # snowflake
u'ocsp_response_cache_filename': None, # snowflake internal
u'converter_class':
Expand Down Expand Up @@ -139,6 +140,8 @@ def __init__(self, **kwargs):
for name, value in DEFAULT_CONFIGURATION.items():
setattr(self, u'_' + name, value)

self.heartbeat_thread = None

self.converter = None
self.connect(**kwargs)
self._telemetry = TelemetryClient(self._rest)
Expand Down Expand Up @@ -273,6 +276,13 @@ def network_timeout(self):
return int(self._network_timeout) if self._network_timeout is not \
None else None

@property
def keep_alive(self):
u"""
Keep connection alive by issuing a hourly heartbeat (SELECT 1;).
"""
return self._client_session_keep_alive

@property
def rest(self):
u"""
Expand Down Expand Up @@ -344,6 +354,8 @@ def close(self):
if not self.rest:
return

self.cancel_heartbeat()

# close telemetry first, since it needs rest to send remaining data
logger.info('closed')
self._telemetry.close()
Expand Down Expand Up @@ -532,6 +544,9 @@ def __open_connection(self):

self._password = None # ensure password won't persist

if self.keep_alive:
self.add_heartbeat()

def __config(self, **kwargs):
u"""
Sets the parameters
Expand Down Expand Up @@ -891,6 +906,24 @@ def set_telemetry_enabled(self, enabled):
"""
self._telemetry_enabled = enabled

def add_heartbeat(self):
"""Add an hourly heartbeat query in order to keep connection alive."""
if not self.heartbeat_thread:
self.heartbeat_thread = HourlyTimer(self.heartbeat_tick)
self.heartbeat_thread.start()

def cancel_heartbeat(self):
"""Cancel a heartbeat thread."""
if self.heartbeat_thread:
self.heartbeat_thread.cancel()
self.heartbeat_thread.join()
self.heartbeat_thread = None

def heartbeat_tick(self):
"""Execute a hearbeat query if connection isn't closed yet."""
if not self.is_closed():
self.execute_string("SELECT 1;")

def __enter__(self):
u"""
context manager
Expand Down
25 changes: 25 additions & 0 deletions test/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,34 @@ def test_with_config(db_parameters):
}
cnx = snowflake.connector.connect(**config)
assert cnx, 'invalid cnx'

assert not cnx.keep_alive # default is False

cnx.close()


def test_keep_alive_true(db_parameters):
"""
Creates a connection with client_session_keep_alive parameter.
"""
config = {
'user': db_parameters['user'],
'password': db_parameters['password'],
'host': db_parameters['host'],
'port': db_parameters['port'],
'account': db_parameters['account'],
'schema': db_parameters['schema'],
'database': db_parameters['database'],
'protocol': db_parameters['protocol'],
'timezone': 'UTC',
'client_session_keep_alive': True
}
cnx = snowflake.connector.connect(**config)
assert cnx.keep_alive
cnx.close()



def test_bad_db(db_parameters):
"""
Attempts to use a bad DB
Expand Down
19 changes: 19 additions & 0 deletions time_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,25 @@
#
# Copyright (c) 2018 Snowflake Computing Inc. All right reserved.
import time
try:
from threading import _Timer as Timer
except ImportError:
from threading import Timer


class HourlyTimer(Timer):
"""A thread which executes a function every hour."""

def __init__(self, function, args=None, kwargs={}):
super(HourlyTimer, self).__init__(
self, function, args=args, kwargs=kwargs)
self.interval = 60.0 * 60 # one hour

def run(self):
while not self.finished.is_set():
self.finished.wait(self.interval)
if not self.finished.is_set():
self.function(*self.args, **self.kwargs)


def get_time_millis():
Expand Down