Commit 3abc2eca authored by Sean Bleier's avatar Sean Bleier

Refactor class importing.

parent 22c233de
...@@ -3,8 +3,6 @@ from django.core.exceptions import ImproperlyConfigured ...@@ -3,8 +3,6 @@ from django.core.exceptions import ImproperlyConfigured
from django.utils import importlib from django.utils import importlib
from django.utils.importlib import import_module from django.utils.importlib import import_module
from redis_cache.compat import smart_bytes, DEFAULT_TIMEOUT
try: try:
import redis import redis
except ImportError: except ImportError:
...@@ -14,8 +12,11 @@ except ImportError: ...@@ -14,8 +12,11 @@ except ImportError:
from redis.connection import DefaultParser from redis.connection import DefaultParser
from redis_cache.compat import smart_bytes, DEFAULT_TIMEOUT
from redis_cache.connection import pool from redis_cache.connection import pool
from redis_cache.utils import CacheKey, get_servers, parse_connection_kwargs from redis_cache.utils import (
CacheKey, get_servers, parse_connection_kwargs, import_class
)
from functools import wraps from functools import wraps
...@@ -92,18 +93,10 @@ class BaseRedisCache(BaseCache): ...@@ -92,18 +93,10 @@ class BaseRedisCache(BaseCache):
return self.params.get('password', self.options.get('PASSWORD', None)) return self.params.get('password', self.options.get('PASSWORD', None))
def get_parser_class(self): def get_parser_class(self):
cls = self.options.get('PARSER_CLASS', None) parser_class = self.options.get('PARSER_CLASS', None)
if cls is None: if parser_class is None:
return DefaultParser return DefaultParser
mod_path, cls_name = cls.rsplit('.', 1) return import_class(parser_class)
try:
mod = importlib.import_module(mod_path)
parser_class = getattr(mod, cls_name)
except AttributeError:
raise ImproperlyConfigured("Could not find parser class '%s'" % parser_class)
except ImportError as ex:
raise ImproperlyConfigured("Could not find module '%s'" % ex)
return parser_class
def get_pickle_version(self): def get_pickle_version(self):
""" """
...@@ -120,12 +113,7 @@ class BaseRedisCache(BaseCache): ...@@ -120,12 +113,7 @@ class BaseRedisCache(BaseCache):
def get_connection_pool_class(self): def get_connection_pool_class(self):
pool_class = self.options.get('CONNECTION_POOL_CLASS', 'redis.ConnectionPool') pool_class = self.options.get('CONNECTION_POOL_CLASS', 'redis.ConnectionPool')
module_name, class_name = pool_class.rsplit('.', 1) return import_class(pool_class)
module = import_module(module_name)
try:
return getattr(module, class_name)
except AttributeError:
raise ImportError('cannot import name %s' % class_name)
def get_connection_pool_class_kwargs(self): def get_connection_pool_class_kwargs(self):
return self.options.get('CONNECTION_POOL_CLASS_KWARGS', {}) return self.options.get('CONNECTION_POOL_CLASS_KWARGS', {})
...@@ -135,12 +123,7 @@ class BaseRedisCache(BaseCache): ...@@ -135,12 +123,7 @@ class BaseRedisCache(BaseCache):
'SERIALIZER_CLASS', 'SERIALIZER_CLASS',
'redis_cache.serializers.PickleSerializer' 'redis_cache.serializers.PickleSerializer'
) )
module_name, class_name = serializer_class.rsplit('.', 1) return import_class(serializer_class)
module = import_module(module_name)
try:
return getattr(module, class_name)
except AttributeError:
raise ImportError('cannot import name %s' % class_name)
def get_serializer_class_kwargs(self): def get_serializer_class_kwargs(self):
return self.options.get('SERIALIZER_CLASS_KWARGS', {}) return self.options.get('SERIALIZER_CLASS_KWARGS', {})
...@@ -150,12 +133,7 @@ class BaseRedisCache(BaseCache): ...@@ -150,12 +133,7 @@ class BaseRedisCache(BaseCache):
'COMPRESSOR_CLASS', 'COMPRESSOR_CLASS',
'redis_cache.compressors.NoopCompressor' 'redis_cache.compressors.NoopCompressor'
) )
module_name, class_name = compressor_class.rsplit('.', 1) return import_class(compressor_class)
module = import_module(module_name)
try:
return getattr(module, class_name)
except AttributeError:
raise ImportError('cannot import name %s' % class_name)
def get_compressor_class_kwargs(self): def get_compressor_class_kwargs(self):
return self.options.get('COMPRESSOR_CLASS_KWARGS', {}) return self.options.get('COMPRESSOR_CLASS_KWARGS', {})
......
import warnings import warnings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.utils.importlib import import_module
from redis.connection import SSLConnection from redis.connection import SSLConnection
from redis_cache.compat import ( from redis_cache.compat import (
...@@ -49,6 +50,19 @@ def get_servers(location): ...@@ -49,6 +50,19 @@ def get_servers(location):
return servers return servers
def import_class(path):
module_name, class_name = path.rsplit('.', 1)
try:
module = import_module(module_name)
except ImportError:
raise ImproperlyConfigured('Could not find module "%s"' % module_name)
else:
try:
return getattr(module, class_name)
except AttributeError:
raise ImproperlyConfigured('Cannot import "%s"' % class_name)
def parse_connection_kwargs(server, db=None, **kwargs): def parse_connection_kwargs(server, db=None, **kwargs):
""" """
Return a connection pool configured from the given URL. Return a connection pool configured from the given URL.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment