Commit 9cfe32b3 authored by Sean Bleier's avatar Sean Bleier

Adds ability to specify connection pool class.

parent a7bb91b0
......@@ -2,6 +2,7 @@ from django.core.cache.backends.base import BaseCache, InvalidCacheBackendError
from django.core.exceptions import ImproperlyConfigured
from django.utils import importlib
from django.utils.datastructures import SortedDict
from django.utils.importlib import import_module
from .compat import (smart_text, smart_bytes, bytes_type,
python_2_unicode_compatible, DEFAULT_TIMEOUT)
......@@ -47,8 +48,8 @@ class CacheConnectionPool(object):
def get_connection_pool(self, host='127.0.0.1', port=6379, db=1,
password=None, parser_class=None,
unix_socket_path=None, max_connections=None):
connection_identifier = (host, port, db, parser_class, unix_socket_path, max_connections)
unix_socket_path=None, max_connections=None, connection_pool_class=None):
connection_identifier = (host, port, db, parser_class, unix_socket_path, max_connections, connection_pool_class)
if not self._connection_pools.get(connection_identifier):
connection_class = (
unix_socket_path and UnixDomainSocketConnection or Connection
......@@ -67,7 +68,7 @@ class CacheConnectionPool(object):
})
else:
kwargs['path'] = unix_socket_path
self._connection_pools[connection_identifier] = redis.ConnectionPool(**kwargs)
self._connection_pools[connection_identifier] = connection_pool_class(**kwargs)
return self._connection_pools[connection_identifier]
pool = CacheConnectionPool()
......@@ -106,6 +107,7 @@ class CacheClass(BaseCache):
connection_pool = pool.get_connection_pool(
parser_class=self.parser_class,
max_connections=self.max_connections,
connection_pool_class=self.connection_pool_class,
**kwargs
)
self._client = redis.Redis(
......@@ -129,6 +131,17 @@ class CacheClass(BaseCache):
def max_connections(self):
return self.options.get('MAX_CONNECTIONS', None)
@property
def connection_pool_class(self):
pool_class = self.options.get('CONNECTION_POOL_CLASS', 'redis.ConnectionPool')
module_name, class_name = pool_class.rsplit('.', 1)
module = import_module(module_name)
try:
cls = getattr(module, class_name)
except AttributeError:
raise ImportError('cannot import name %s' % class_name)
return cls
@property
def db(self):
_db = self.params.get('db', self.options.get('DB', 1))
......
......@@ -23,6 +23,7 @@ cache_settings = {
'PASSWORD': 'yadayada',
'PARSER_CLASS': 'redis.connection.HiredisParser',
'MAX_CONNECTIONS': 2,
'CONNECTION_POOL_CLASS': 'redis.ConnectionPool',
},
},
},
......
......@@ -379,18 +379,12 @@ class RedisCacheTests(TestCase):
release = cache._client.connection_pool.release
cache._client.connection_pool.release = noop
self.assertEqual(cache._client.connection_pool.max_connections, 2)
cache.set('a', 'a')
self.assertEqual(cache._client.connection_pool._created_connections, 1)
cache.set('a', 'a')
self.assertEqual(cache._client.connection_pool._created_connections, 2)
with self.assertRaises(redis.ConnectionError):
cache.set('a', 'a')
self.assertEqual(cache._client.connection_pool._created_connections, 2)
cache._client.connection_pool.release = release
cache._client.connection_pool.max_connections = 2**31
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment