Commit 94eab537 authored by Sean Bleier's avatar Sean Bleier

Add docstring and test for get_or_set.

parent 18a58d06
...@@ -406,13 +406,28 @@ class BaseRedisCache(BaseCache): ...@@ -406,13 +406,28 @@ class BaseRedisCache(BaseCache):
) )
@get_client(write=True) @get_client(write=True)
def get_or_set(self, client, key, func, timeout=DEFAULT_TIMEOUT, lock_timeout=None, stale_cache_timeout=0): def get_or_set(
self,
client,
key,
func,
timeout=DEFAULT_TIMEOUT,
lock_timeout=None,
stale_cache_timeout=None):
"""Get a value from the cache or call `func` to set it and return it. """Get a value from the cache or call `func` to set it and return it.
This implementation is slightly more advanced that Django's. It provides thundering herd This implementation is slightly more advanced that Django's. It provides thundering herd
protection that prevents multiple threads/processes from calling the value-generating protection, which prevents multiple threads/processes from calling the value-generating
function too much. function too much.
There are three timeouts you can specify:
`timeout`: Time in seconds that value at `key` is considered fresh.
`lock_timeout`: Time in seconds that the lock will stay active and prevent other threads or
processes from acquiring the lock.
`stale_cache_timeout`: Time in seconds that the stale cache will remain after the key has
expired. If `None` is specified, the stale value will remain indefinitely.
""" """
if not callable(func): if not callable(func):
raise Exception("Must pass in a callable") raise Exception("Must pass in a callable")
...@@ -422,6 +437,7 @@ class BaseRedisCache(BaseCache): ...@@ -422,6 +437,7 @@ class BaseRedisCache(BaseCache):
is_fresh = self._get(client, fresh_key) is_fresh = self._get(client, fresh_key)
value = self._get(client, key) value = self._get(client, key)
if is_fresh: if is_fresh:
return value return value
...@@ -436,9 +452,12 @@ class BaseRedisCache(BaseCache): ...@@ -436,9 +452,12 @@ class BaseRedisCache(BaseCache):
except Exception: except Exception:
raise raise
else: else:
key_timeout = (
None if stale_cache_timeout is None else timeout + stale_cache_timeout
)
pipeline = client.pipeline() pipeline = client.pipeline()
pipeline.set(key, self.prep_value(value), timeout + stale_cache_timeout) pipeline.set(key, self.prep_value(value), key_timeout)
pipeline.set(fresh_key, self.prep_value(1), timeout) pipeline.set(fresh_key, 1, timeout)
pipeline.execute() pipeline.execute()
finally: finally:
lock.release() lock.release()
......
...@@ -4,6 +4,7 @@ from __future__ import unicode_literals ...@@ -4,6 +4,7 @@ from __future__ import unicode_literals
from hashlib import sha1 from hashlib import sha1
import os import os
import subprocess import subprocess
import threading
import time import time
...@@ -510,6 +511,66 @@ class BaseRedisTestCase(SetupMixin): ...@@ -510,6 +511,66 @@ class BaseRedisTestCase(SetupMixin):
self.assertEqual(expensive_function.num_calls, 2) self.assertEqual(expensive_function.num_calls, 2)
self.assertEqual(value, 42) self.assertEqual(value, 42)
def test_get_or_set_serving_from_stale_value(self):
def expensive_function(x):
time.sleep(.5)
expensive_function.num_calls += 1
return x
expensive_function.num_calls = 0
self.assertEqual(expensive_function.num_calls, 0)
results = {}
def thread_worker(thread_id, return_value, timeout, lock_timeout, stale_cache_timeout):
value = self.cache.get_or_set(
'key',
lambda: expensive_function(return_value),
timeout,
lock_timeout,
stale_cache_timeout
)
results[thread_id] = value
thread_0 = threading.Thread(target=thread_worker, args=(0, 'a', 1, None, 1))
thread_1 = threading.Thread(target=thread_worker, args=(1, 'b', 1, None, 1))
thread_2 = threading.Thread(target=thread_worker, args=(2, 'c', 1, None, 1))
thread_3 = threading.Thread(target=thread_worker, args=(3, 'd', 1, None, 1))
thread_4 = threading.Thread(target=thread_worker, args=(4, 'e', 1, None, 1))
# First thread should complete and return its value
thread_0.start() # t = 0, valid from t = .5 - 1.5, stale from t = 1.5 - 2.5
# Second thread will start while the first thread is still working and return None.
time.sleep(.25) # t = .25
thread_1.start()
# Third thread will start after the first value is computed, but before it expires.
# its value.
time.sleep(.5) # t = .75
thread_2.start()
# Fourth thread will start after the first value has expired and will re-compute its value.
# valid from t = 2.25 - 3.25, stale from t = 3.75 - 4.75.
time.sleep(1) # t = 1.75
thread_3.start()
# Fifth thread will start after the fourth thread has started to compute its value, but
# before the first thread's stale cache has expired.
time.sleep(.25) # t = 2
thread_4.start()
thread_0.join()
thread_1.join()
thread_2.join()
thread_3.join()
thread_4.join()
self.assertEqual(results, {
0: 'a',
1: None,
2: 'a',
3: 'd',
4: 'a'
})
def assertMaxConnection(self, cache, max_num): def assertMaxConnection(self, cache, max_num):
for client in cache.clients.values(): for client in cache.clients.values():
self.assertLessEqual(client.connection_pool._created_connections, max_num) self.assertLessEqual(client.connection_pool._created_connections, max_num)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment