Commit 9cfe32b3 authored by Sean Bleier's avatar Sean Bleier

Adds ability to specify connection pool class.

parent a7bb91b0
...@@ -2,6 +2,7 @@ from django.core.cache.backends.base import BaseCache, InvalidCacheBackendError ...@@ -2,6 +2,7 @@ from django.core.cache.backends.base import BaseCache, InvalidCacheBackendError
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.utils import importlib from django.utils import importlib
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from django.utils.importlib import import_module
from .compat import (smart_text, smart_bytes, bytes_type, from .compat import (smart_text, smart_bytes, bytes_type,
python_2_unicode_compatible, DEFAULT_TIMEOUT) python_2_unicode_compatible, DEFAULT_TIMEOUT)
...@@ -47,8 +48,8 @@ class CacheConnectionPool(object): ...@@ -47,8 +48,8 @@ class CacheConnectionPool(object):
def get_connection_pool(self, host='127.0.0.1', port=6379, db=1, def get_connection_pool(self, host='127.0.0.1', port=6379, db=1,
password=None, parser_class=None, password=None, parser_class=None,
unix_socket_path=None, max_connections=None): unix_socket_path=None, max_connections=None, connection_pool_class=None):
connection_identifier = (host, port, db, parser_class, unix_socket_path, max_connections) connection_identifier = (host, port, db, parser_class, unix_socket_path, max_connections, connection_pool_class)
if not self._connection_pools.get(connection_identifier): if not self._connection_pools.get(connection_identifier):
connection_class = ( connection_class = (
unix_socket_path and UnixDomainSocketConnection or Connection unix_socket_path and UnixDomainSocketConnection or Connection
...@@ -67,7 +68,7 @@ class CacheConnectionPool(object): ...@@ -67,7 +68,7 @@ class CacheConnectionPool(object):
}) })
else: else:
kwargs['path'] = unix_socket_path kwargs['path'] = unix_socket_path
self._connection_pools[connection_identifier] = redis.ConnectionPool(**kwargs) self._connection_pools[connection_identifier] = connection_pool_class(**kwargs)
return self._connection_pools[connection_identifier] return self._connection_pools[connection_identifier]
pool = CacheConnectionPool() pool = CacheConnectionPool()
...@@ -106,6 +107,7 @@ class CacheClass(BaseCache): ...@@ -106,6 +107,7 @@ class CacheClass(BaseCache):
connection_pool = pool.get_connection_pool( connection_pool = pool.get_connection_pool(
parser_class=self.parser_class, parser_class=self.parser_class,
max_connections=self.max_connections, max_connections=self.max_connections,
connection_pool_class=self.connection_pool_class,
**kwargs **kwargs
) )
self._client = redis.Redis( self._client = redis.Redis(
...@@ -129,6 +131,17 @@ class CacheClass(BaseCache): ...@@ -129,6 +131,17 @@ class CacheClass(BaseCache):
def max_connections(self): def max_connections(self):
return self.options.get('MAX_CONNECTIONS', None) return self.options.get('MAX_CONNECTIONS', None)
@property
def connection_pool_class(self):
pool_class = self.options.get('CONNECTION_POOL_CLASS', 'redis.ConnectionPool')
module_name, class_name = pool_class.rsplit('.', 1)
module = import_module(module_name)
try:
cls = getattr(module, class_name)
except AttributeError:
raise ImportError('cannot import name %s' % class_name)
return cls
@property @property
def db(self): def db(self):
_db = self.params.get('db', self.options.get('DB', 1)) _db = self.params.get('db', self.options.get('DB', 1))
......
...@@ -23,6 +23,7 @@ cache_settings = { ...@@ -23,6 +23,7 @@ cache_settings = {
'PASSWORD': 'yadayada', 'PASSWORD': 'yadayada',
'PARSER_CLASS': 'redis.connection.HiredisParser', 'PARSER_CLASS': 'redis.connection.HiredisParser',
'MAX_CONNECTIONS': 2, 'MAX_CONNECTIONS': 2,
'CONNECTION_POOL_CLASS': 'redis.ConnectionPool',
}, },
}, },
}, },
......
...@@ -379,18 +379,12 @@ class RedisCacheTests(TestCase): ...@@ -379,18 +379,12 @@ class RedisCacheTests(TestCase):
release = cache._client.connection_pool.release release = cache._client.connection_pool.release
cache._client.connection_pool.release = noop cache._client.connection_pool.release = noop
self.assertEqual(cache._client.connection_pool.max_connections, 2)
cache.set('a', 'a') cache.set('a', 'a')
self.assertEqual(cache._client.connection_pool._created_connections, 1)
cache.set('a', 'a') cache.set('a', 'a')
self.assertEqual(cache._client.connection_pool._created_connections, 2)
with self.assertRaises(redis.ConnectionError): with self.assertRaises(redis.ConnectionError):
cache.set('a', 'a') cache.set('a', 'a')
self.assertEqual(cache._client.connection_pool._created_connections, 2)
cache._client.connection_pool.release = release cache._client.connection_pool.release = release
cache._client.connection_pool.max_connections = 2**31 cache._client.connection_pool.max_connections = 2**31
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment