introduce mutex to avoid multiple token refresh when using threads.

This commit is contained in:
Yoshihiro OKUMURA 2023-05-09 19:30:07 +09:00
parent 8ee220137f
commit 35defa39a6
Signed by: orrisroot
GPG Key ID: 470AA444C92904B2
3 changed files with 20 additions and 12 deletions

View File

@ -4,6 +4,8 @@ from mdrsclient.exceptions import UnauthorizedException
def token_check(connection: MDRSConnection) -> None: def token_check(connection: MDRSConnection) -> None:
try:
connection.lock.acquire()
if connection.token is not None: if connection.token is not None:
if connection.token.is_refresh_required: if connection.token.is_refresh_required:
user_api = UserApi(connection) user_api = UserApi(connection)
@ -13,3 +15,5 @@ def token_check(connection: MDRSConnection) -> None:
connection.logout() connection.logout()
elif connection.token.is_expired: elif connection.token.is_expired:
connection.logout() connection.logout()
finally:
connection.lock.release()

View File

@ -1,3 +1,5 @@
import threading
import requests import requests
from mdrsclient.cache import CacheFile from mdrsclient.cache import CacheFile
@ -8,12 +10,14 @@ from mdrsclient.models import Laboratories, Token, User
class MDRSConnection: class MDRSConnection:
url: str url: str
session: requests.Session session: requests.Session
lock: threading.Lock
__cache: CacheFile __cache: CacheFile
def __init__(self, remote: str, url: str) -> None: def __init__(self, remote: str, url: str) -> None:
super().__init__() super().__init__()
self.url = url self.url = url
self.session = requests.Session() self.session = requests.Session()
self.lock = threading.Lock()
self.__cache = CacheFile(remote) self.__cache = CacheFile(remote)
self.__prepare_headers() self.__prepare_headers()

View File

@ -26,16 +26,16 @@ class Token:
@property @property
def is_expired(self) -> bool: def is_expired(self) -> bool:
now = int(time.time()) + 10 now = int(time.time())
refresh_decoded = self.__decode(self.refresh) refresh_decoded = self.__decode(self.refresh)
return now > refresh_decoded.exp return (now - 10) > refresh_decoded.exp
@property @property
def is_refresh_required(self) -> bool: def is_refresh_required(self) -> bool:
now = int(time.time()) + 10 now = int(time.time())
access_decoded = self.__decode(self.access) access_decoded = self.__decode(self.access)
refresh_decoded = self.__decode(self.refresh) refresh_decoded = self.__decode(self.refresh)
return now > access_decoded.exp and now < refresh_decoded.exp return (now + 10) > access_decoded.exp and (now - 10) < refresh_decoded.exp
def __decode(self, token: str) -> DecodedJWT: def __decode(self, token: str) -> DecodedJWT:
data = jwt.decode(token, options={"verify_signature": False}) data = jwt.decode(token, options={"verify_signature": False})