Commit f15f5d40 authored by Sean Bleier's avatar Sean Bleier

Update to support Redis-py 3.x

parent 777d4d55
...@@ -19,3 +19,4 @@ MANIFEST ...@@ -19,3 +19,4 @@ MANIFEST
.venv .venv
redis/ redis/
*/_build/ */_build/
build/*
language: python language: python
python: python:
- 2.7 - 2.7
- 3.4
- 3.5 - 3.5
- 3.6 - 3.6
env: env:
......
...@@ -21,6 +21,16 @@ Docs can be found at http://django-redis-cache.readthedocs.org/en/latest/. ...@@ -21,6 +21,16 @@ Docs can be found at http://django-redis-cache.readthedocs.org/en/latest/.
Changelog Changelog
========= =========
2.0.0
-----
* Adds support for redis-py >= 3.0.
* Drops support for Redis 2.6
* Drops support for Python 3.4
* Removes custom ``expire`` method in lieu of Django's ``touch``.
* Removes ``CacheKey`` in favor of string literals.
1.8.0 1.8.0
----- -----
......
from functools import wraps
from django.core.cache.backends.base import ( from django.core.cache.backends.base import (
BaseCache, DEFAULT_TIMEOUT, InvalidCacheBackendError, BaseCache, DEFAULT_TIMEOUT, InvalidCacheBackendError,
) )
...@@ -11,14 +13,9 @@ except ImportError: ...@@ -11,14 +13,9 @@ except ImportError:
) )
from redis.connection import DefaultParser from redis.connection import DefaultParser
from redis_cache.constants import KEY_EXPIRED, KEY_NON_VOLATILE
from redis_cache.connection import pool from redis_cache.connection import pool
from redis_cache.utils import ( from redis_cache.utils import get_servers, parse_connection_kwargs, import_class
CacheKey, get_servers, parse_connection_kwargs, import_class
)
from functools import wraps
def get_client(write=False): def get_client(write=False):
...@@ -28,8 +25,8 @@ def get_client(write=False): ...@@ -28,8 +25,8 @@ def get_client(write=False):
@wraps(method) @wraps(method)
def wrapped(self, key, *args, **kwargs): def wrapped(self, key, *args, **kwargs):
version = kwargs.pop('version', None) version = kwargs.pop('version', None)
client = self.get_client(key, write=write)
key = self.make_key(key, version=version) key = self.make_key(key, version=version)
client = self.get_client(key, write=write)
return method(self, client, key, *args, **kwargs) return method(self, client, key, *args, **kwargs)
return wrapped return wrapped
...@@ -76,6 +73,12 @@ class BaseRedisCache(BaseCache): ...@@ -76,6 +73,12 @@ class BaseRedisCache(BaseCache):
**self.compressor_class_kwargs **self.compressor_class_kwargs
) )
redis_py_version = (int(part) for part in redis.__version__.split('.'))
if redis_py_version < (3, 0, 0):
self.Redis = redis.StrictRedis
else:
self.Redis = redis.Redis
def __getstate__(self): def __getstate__(self):
return {'params': self.params, 'server': self.server} return {'params': self.params, 'server': self.server}
...@@ -180,7 +183,7 @@ class BaseRedisCache(BaseCache): ...@@ -180,7 +183,7 @@ class BaseRedisCache(BaseCache):
socket_timeout=self.socket_timeout, socket_timeout=self.socket_timeout,
socket_connect_timeout=self.socket_connect_timeout, socket_connect_timeout=self.socket_connect_timeout,
) )
client = redis.Redis(**kwargs) client = self.Redis(**kwargs)
kwargs.update( kwargs.update(
parser_class=self.parser_class, parser_class=self.parser_class,
connection_pool_class=self.connection_pool_class, connection_pool_class=self.connection_pool_class,
...@@ -216,12 +219,6 @@ class BaseRedisCache(BaseCache): ...@@ -216,12 +219,6 @@ class BaseRedisCache(BaseCache):
value = self.serialize(value) value = self.serialize(value)
return self.compress(value) return self.compress(value)
def make_key(self, key, version=None):
if not isinstance(key, CacheKey):
versioned_key = super(BaseRedisCache, self).make_key(key, version)
return CacheKey(key, versioned_key)
return key
def make_keys(self, keys, version=None): def make_keys(self, keys, version=None):
return [self.make_key(key, version=version) for key in keys] return [self.make_key(key, version=version) for key in keys]
...@@ -247,41 +244,34 @@ class BaseRedisCache(BaseCache): ...@@ -247,41 +244,34 @@ class BaseRedisCache(BaseCache):
timeout = self.get_timeout(timeout) timeout = self.get_timeout(timeout)
return self._set(client, key, self.prep_value(value), timeout, _add_only=True) return self._set(client, key, self.prep_value(value), timeout, _add_only=True)
def _get(self, client, key, default=None):
value = client.get(key)
if value is None:
return default
value = self.get_value(value)
return value
@get_client() @get_client()
def get(self, client, key, default=None): def get(self, client, key, default=None):
"""Retrieve a value from the cache. """Retrieve a value from the cache.
Returns deserialized value if key is found, the default if not. Returns deserialized value if key is found, the default if not.
""" """
value = client.get(key) return self._get(client, key, default)
if value is None:
return default
value = self.get_value(value)
return value
def _set(self, client, key, value, timeout, _add_only=False): def _set(self, client, key, value, timeout, _add_only=False):
if timeout is None or timeout == 0: if timeout is not None and timeout < 0:
if _add_only:
return client.setnx(key, value)
return client.set(key, value)
elif timeout > 0:
if _add_only:
added = client.setnx(key, value)
if added:
client.expire(key, timeout)
return added
return client.setex(key, value, timeout)
else:
return False return False
elif timeout == 0:
return client.expire(key, 0)
return client.set(key, value, nx=_add_only, ex=timeout)
@get_client(write=True) @get_client(write=True)
def set(self, client, key, value, timeout=DEFAULT_TIMEOUT): def set(self, client, key, value, timeout=DEFAULT_TIMEOUT):
"""Persist a value to the cache, and set an optional expiration time. """Persist a value to the cache, and set an optional expiration time.
""" """
timeout = self.get_timeout(timeout) timeout = self.get_timeout(timeout)
result = self._set(client, key, self.prep_value(value), timeout, _add_only=False) result = self._set(client, key, self.prep_value(value), timeout, _add_only=False)
return result return result
@get_client(write=True) @get_client(write=True)
...@@ -328,12 +318,6 @@ class BaseRedisCache(BaseCache): ...@@ -328,12 +318,6 @@ class BaseRedisCache(BaseCache):
"""Retrieve many keys.""" """Retrieve many keys."""
raise NotImplementedError raise NotImplementedError
def _set_many(self, client, data):
# Only call mset if there actually is some data to save
if not data:
return True
return client.mset(data)
def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None): def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
"""Set a bunch of values in the cache at once from a dict of key/value """Set a bunch of values in the cache at once from a dict of key/value
pairs. This is much more efficient than calling set() multiple times. pairs. This is much more efficient than calling set() multiple times.
...@@ -351,26 +335,27 @@ class BaseRedisCache(BaseCache): ...@@ -351,26 +335,27 @@ class BaseRedisCache(BaseCache):
exists = client.exists(key) exists = client.exists(key)
if not exists: if not exists:
raise ValueError("Key '%s' not found" % key) raise ValueError("Key '%s' not found" % key)
try:
value = client.incr(key, delta) value = client.incr(key, delta)
except redis.ResponseError:
key = key._original_key
value = self.get(key) + delta
self.set(key, value, timeout=None)
return value return value
def _incr_version(self, client, old, new, delta, version): def _incr_version(self, client, old, new, original, delta, version):
try: try:
client.rename(old, new) client.rename(old, new)
except redis.ResponseError: except redis.ResponseError:
raise ValueError("Key '%s' not found" % old._original_key) raise ValueError("Key '%s' not found" % original)
return version + delta return version + delta
def incr_version(self, key, delta=1, version=None): def incr_version(self, key, delta=1, version=None):
"""Adds delta to the cache version for the supplied key. Returns the """Adds delta to the cache version for the supplied key. Returns the
new version. new version.
""" """
raise NotImplementedError
@get_client(write=True)
def touch(self, client, key, timeout=DEFAULT_TIMEOUT):
"""Reset the timeout of a key to `timeout` seconds."""
return client.expire(key, timeout)
##################### #####################
# Extra api methods # # Extra api methods #
...@@ -388,9 +373,13 @@ class BaseRedisCache(BaseCache): ...@@ -388,9 +373,13 @@ class BaseRedisCache(BaseCache):
Otherwise, the value is the number of seconds remaining. If the key Otherwise, the value is the number of seconds remaining. If the key
does not exist, 0 is returned. does not exist, 0 is returned.
""" """
if client.exists(key): ttl = client.ttl(key)
return client.ttl(key) if ttl == KEY_NON_VOLATILE:
return None
elif ttl == KEY_EXPIRED:
return 0 return 0
else:
return ttl
def _delete_pattern(self, client, pattern): def _delete_pattern(self, client, pattern):
keys = list(client.scan_iter(match=pattern)) keys = list(client.scan_iter(match=pattern))
...@@ -405,11 +394,11 @@ class BaseRedisCache(BaseCache): ...@@ -405,11 +394,11 @@ class BaseRedisCache(BaseCache):
if not callable(func): if not callable(func):
raise Exception("Must pass in a callable") raise Exception("Must pass in a callable")
value = self.get(key._original_key) value = self._get(client, key)
if value is None: if value is None:
dogpile_lock_key = "_lock" + key._versioned_key dogpile_lock_key = "_lock" + key
dogpile_lock = client.get(dogpile_lock_key) dogpile_lock = client.get(dogpile_lock_key)
if dogpile_lock is None: if dogpile_lock is None:
...@@ -453,12 +442,3 @@ class BaseRedisCache(BaseCache): ...@@ -453,12 +442,3 @@ class BaseRedisCache(BaseCache):
Returns True if successful and False if not. Returns True if successful and False if not.
""" """
return client.persist(key) return client.persist(key)
@get_client(write=True)
def expire(self, client, key, timeout):
"""
Set the expire time on a key
returns True if successful and False if not.
"""
return client.expire(key, timeout)
...@@ -29,9 +29,8 @@ class ShardedRedisCache(BaseRedisCache): ...@@ -29,9 +29,8 @@ class ShardedRedisCache(BaseRedisCache):
""" """
clients = defaultdict(list) clients = defaultdict(list)
for key in keys: for key in keys:
clients[self.get_client(key, write)].append( versioned_key = self.make_key(key, version=version)
self.make_key(key, version) clients[self.get_client(versioned_key, write)].append(versioned_key)
)
return clients return clients
#################### ####################
...@@ -63,41 +62,28 @@ class ShardedRedisCache(BaseRedisCache): ...@@ -63,41 +62,28 @@ class ShardedRedisCache(BaseRedisCache):
data = {} data = {}
clients = self.shard(keys, version=version) clients = self.shard(keys, version=version)
for client, versioned_keys in clients.items(): for client, versioned_keys in clients.items():
original_keys = [key._original_key for key in versioned_keys] versioned_keys = [self.make_key(key, version=version) for key in keys]
data.update( data.update(
self._get_many( self._get_many(client, keys, versioned_keys=versioned_keys)
client,
original_keys,
versioned_keys=versioned_keys
)
) )
return data return data
def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None): def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
""" """
Set a bunch of values in the cache at once from a dict of key/value Set multiple values in the cache at once from a dict of key/value pairs.
pairs. This is much more efficient than calling set() multiple times.
If timeout is given, that timeout will be used for the key; otherwise If timeout is given, that timeout will be used for the key; otherwise
the default cache timeout will be used. the default cache timeout will be used.
""" """
timeout = self.get_timeout(timeout) timeout = self.get_timeout(timeout)
versioned_key_to_key = {self.make_key(key, version=version): key for key in data.keys()}
clients = self.shard(versioned_key_to_key.values(), write=True, version=version)
clients = self.shard(data.keys(), write=True, version=version) for client, versioned_keys in clients.items():
if timeout is None:
for client, keys in clients.items():
subset = {}
for key in keys:
subset[key] = self.prep_value(data[key._original_key])
self._set_many(client, subset)
return
for client, keys in clients.items():
pipeline = client.pipeline() pipeline = client.pipeline()
for key in keys: for versioned_key in versioned_keys:
value = self.prep_value(data[key._original_key]) value = self.prep_value(data[versioned_key_to_key[versioned_key]])
self._set(pipeline, key, value, timeout) self._set(pipeline, versioned_key, value, timeout)
pipeline.execute() pipeline.execute()
def incr_version(self, key, delta=1, version=None): def incr_version(self, key, delta=1, version=None):
...@@ -113,7 +99,7 @@ class ShardedRedisCache(BaseRedisCache): ...@@ -113,7 +99,7 @@ class ShardedRedisCache(BaseRedisCache):
old = self.make_key(key, version=version) old = self.make_key(key, version=version)
new = self.make_key(key, version=version + delta) new = self.make_key(key, version=version + delta)
return self._incr_version(client, old, new, delta, version) return self._incr_version(client, old, new, key, delta, version)
##################### #####################
# Extra api methods # # Extra api methods #
......
...@@ -56,25 +56,18 @@ class RedisCache(BaseRedisCache): ...@@ -56,25 +56,18 @@ class RedisCache(BaseRedisCache):
def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None): def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
""" """
Set a bunch of values in the cache at once from a dict of key/value Set multiple values in the cache at once from a dict of key/value pairs.
pairs. This is much more efficient than calling set() multiple times.
If timeout is given, that timeout will be used for the key; otherwise If timeout is given, that timeout will be used for the key; otherwise
the default cache timeout will be used. the default cache timeout will be used.
""" """
timeout = self.get_timeout(timeout) timeout = self.get_timeout(timeout)
versioned_keys = self.make_keys(data.keys(), version=version)
if timeout is None:
new_data = {}
for key in versioned_keys:
new_data[key] = self.prep_value(data[key._original_key])
return self._set_many(self.master_client, new_data)
pipeline = self.master_client.pipeline() pipeline = self.master_client.pipeline()
for key in versioned_keys: for key, value in data.items():
value = self.prep_value(data[key._original_key]) value = self.prep_value(value)
self._set(pipeline, key, value, timeout) versioned_key = self.make_key(key, version=version)
self._set(pipeline, versioned_key, value, timeout)
pipeline.execute() pipeline.execute()
def incr_version(self, key, delta=1, version=None): def incr_version(self, key, delta=1, version=None):
...@@ -89,7 +82,7 @@ class RedisCache(BaseRedisCache): ...@@ -89,7 +82,7 @@ class RedisCache(BaseRedisCache):
old = self.make_key(key, version) old = self.make_key(key, version)
new = self.make_key(key, version=version + delta) new = self.make_key(key, version=version + delta)
return self._incr_version(self.master_client, old, new, delta, version) return self._incr_version(self.master_client, old, new, key, delta, version)
##################### #####################
# Extra api methods # # Extra api methods #
......
KEY_EXPIRED = -2
KEY_NON_VOLATILE = -1
...@@ -6,30 +6,10 @@ from django.utils import six ...@@ -6,30 +6,10 @@ from django.utils import six
from django.utils.encoding import force_text, python_2_unicode_compatible from django.utils.encoding import force_text, python_2_unicode_compatible
from django.utils.six.moves.urllib.parse import parse_qs, urlparse from django.utils.six.moves.urllib.parse import parse_qs, urlparse
from redis._compat import unicode
from redis.connection import SSLConnection from redis.connection import SSLConnection
@python_2_unicode_compatible
class CacheKey(object):
"""
A stub string class that we can use to check if a key was created already.
"""
def __init__(self, key, versioned_key):
self._original_key = key
self._versioned_key = versioned_key
def __eq__(self, other):
return self._versioned_key == other
def __str__(self):
return force_text(self._versioned_key)
def __hash__(self):
return hash(self._versioned_key)
__repr__ = __str__
def get_servers(location): def get_servers(location):
"""Returns a list of servers given the server argument passed in from """Returns a list of servers given the server argument passed in from
Django. Django.
......
...@@ -5,11 +5,11 @@ setup( ...@@ -5,11 +5,11 @@ setup(
url="http://github.com/sebleier/django-redis-cache/", url="http://github.com/sebleier/django-redis-cache/",
author="Sean Bleier", author="Sean Bleier",
author_email="sebleier@gmail.com", author_email="sebleier@gmail.com",
version="1.8.1", version="2.0.0",
license="BSD", license="BSD",
packages=["redis_cache", "redis_cache.backends"], packages=["redis_cache", "redis_cache.backends"],
description="Redis Cache Backend for Django", description="Redis Cache Backend for Django",
install_requires=['redis==2.10.6'], install_requires=['redis<4.0'],
classifiers=[ classifiers=[
"Programming Language :: Python", "Programming Language :: Python",
"Programming Language :: Python :: 2.7", "Programming Language :: Python :: 2.7",
......
...@@ -21,6 +21,7 @@ import redis ...@@ -21,6 +21,7 @@ import redis
from tests.testapp.models import Poll, expensive_calculation from tests.testapp.models import Poll, expensive_calculation
from redis_cache.cache import RedisCache, pool from redis_cache.cache import RedisCache, pool
from redis_cache.constants import KEY_EXPIRED, KEY_NON_VOLATILE
from redis_cache.utils import get_servers, parse_connection_kwargs from redis_cache.utils import get_servers, parse_connection_kwargs
...@@ -299,13 +300,12 @@ class BaseRedisTestCase(SetupMixin): ...@@ -299,13 +300,12 @@ class BaseRedisTestCase(SetupMixin):
def test_set_expiration_timeout_zero(self): def test_set_expiration_timeout_zero(self):
key, value = self.cache.make_key('key'), 'value' key, value = self.cache.make_key('key'), 'value'
self.cache.set(key, value, timeout=0) self.cache.set(key, value, timeout=0)
self.assertIsNone(self.cache.get_client(key).ttl(key)) self.assertEqual(self.cache.get_client(key).ttl(key), KEY_EXPIRED)
self.assertIn(key, self.cache) self.assertNotIn(key, self.cache)
def test_set_expiration_timeout_negative(self): def test_set_expiration_timeout_negative(self):
key, value = self.cache.make_key('key'), 'value' key, value = self.cache.make_key('key'), 'value'
self.cache.set(key, value, timeout=-1) self.cache.set(key, value, timeout=-1)
self.assertIsNone(self.cache.get_client(key).ttl(key))
self.assertNotIn(key, self.cache) self.assertNotIn(key, self.cache)
def test_unicode(self): def test_unicode(self):
...@@ -481,9 +481,9 @@ class BaseRedisTestCase(SetupMixin): ...@@ -481,9 +481,9 @@ class BaseRedisTestCase(SetupMixin):
self.cache.set('b', 'b', 5) self.cache.set('b', 'b', 5)
self.cache.reinsert_keys() self.cache.reinsert_keys()
self.assertEqual(self.cache.get('a'), 'a') self.assertEqual(self.cache.get('a'), 'a')
self.assertGreater(self.cache.get_client('a').ttl(self.cache.make_key('a')), 1) self.assertGreater(self.cache.ttl('a'), 1)
self.assertEqual(self.cache.get('b'), 'b') self.assertEqual(self.cache.get('b'), 'b')
self.assertGreater(self.cache.get_client('b').ttl(self.cache.make_key('b')), 1) self.assertGreater(self.cache.ttl('a'), 1)
def test_get_or_set(self): def test_get_or_set(self):
...@@ -581,21 +581,21 @@ class BaseRedisTestCase(SetupMixin): ...@@ -581,21 +581,21 @@ class BaseRedisTestCase(SetupMixin):
self.cache.persist('a') self.cache.persist('a')
self.assertIsNone(self.cache.ttl('a')) self.assertIsNone(self.cache.ttl('a'))
def test_expire_no_expiry_to_expire(self): def test_touch_no_expiry_to_expire(self):
self.cache.set('a', 'a', timeout=None) self.cache.set('a', 'a', timeout=None)
self.cache.expire('a', 10) self.cache.touch('a', 10)
ttl = self.cache.ttl('a') ttl = self.cache.ttl('a')
self.assertAlmostEqual(ttl, 10) self.assertAlmostEqual(ttl, 10)
def test_expire_less(self): def test_touch_less(self):
self.cache.set('a', 'a', timeout=20) self.cache.set('a', 'a', timeout=20)
self.cache.expire('a', 10) self.cache.touch('a', 10)
ttl = self.cache.ttl('a') ttl = self.cache.ttl('a')
self.assertAlmostEqual(ttl, 10) self.assertAlmostEqual(ttl, 10)
def test_expire_more(self): def test_touch_more(self):
self.cache.set('a', 'a', timeout=10) self.cache.set('a', 'a', timeout=10)
self.cache.expire('a', 20) self.cache.touch('a', 20)
ttl = self.cache.ttl('a') ttl = self.cache.ttl('a')
self.assertAlmostEqual(ttl, 20) self.assertAlmostEqual(ttl, 20)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment