add the function to check cache data for tampering.

This commit is contained in:
Yoshihiro OKUMURA 2023-05-08 17:04:03 +09:00
parent 819c4a6a07
commit fb8dfbef10
Signed by: orrisroot
GPG Key ID: 470AA444C92904B2

View File

@ -1,5 +1,6 @@
import dataclasses
import fcntl
import hashlib
import json
import os
@ -7,6 +8,7 @@ from pydantic import ValidationError
from pydantic.dataclasses import dataclass
from pydantic.tools import parse_obj_as
from mdrsclient.exceptions import UnexpectedException
from mdrsclient.models import Laboratories, Token, User
from mdrsclient.settings import CONFIG_DIR_PATH
@ -16,6 +18,18 @@ class CacheData:
user: User | None
token: Token | None
laboratories: Laboratories
digest: str | None
def calc_digest(self) -> str:
return hashlib.sha256(
json.dumps(
[
None if self.user is None else dataclasses.asdict(self.user),
None if self.token is None else dataclasses.asdict(self.token),
dataclasses.asdict(self.laboratories),
]
).encode("utf-8")
).hexdigest()
class CacheFile:
@ -28,7 +42,7 @@ class CacheFile:
self.serial = -1
self.cache_dir = os.path.join(CONFIG_DIR_PATH, "cache")
self.cache_file = os.path.join(self.cache_dir, remote + ".json")
self.data = CacheData(user=None, token=None, laboratories=Laboratories([]))
self.data = CacheData(user=None, token=None, laboratories=Laboratories([]), digest=None)
def dump(self) -> CacheData | None:
self.__load()
@ -88,21 +102,28 @@ class CacheFile:
if self.serial != serial:
try:
with open(self.cache_file) as f:
self.data = parse_obj_as(CacheData, json.load(f))
except ValidationError:
data = parse_obj_as(CacheData, json.load(f))
print(f"{data.digest} : {data.calc_digest()}")
if data.digest != data.calc_digest():
raise UnexpectedException("Cache data has been broken.")
self.data = data
except (ValidationError, UnexpectedException):
self.__clear()
self.__save()
else:
self.serial = serial
else:
self.data.token = None
self.__clear()
self.serial = -1
def __save(self) -> None:
self.__ensure_cache_dir()
with open(self.cache_file, "w") as f:
fcntl.flock(f, fcntl.LOCK_EX)
self.data.digest = self.data.calc_digest()
f.write(json.dumps(dataclasses.asdict(self.data)))
stat = os.stat(self.cache_file)
self.serial = hash((stat.st_uid, stat.st_gid, stat.st_mode, stat.st_size, stat.st_mtime))
# ensure file is secure.
os.chmod(self.cache_file, 0o600)