diff --git a/fastanime/libs/anime_provider/base_provider.py b/fastanime/libs/anime_provider/base_provider.py
index 83b6c07..1bde420 100644
--- a/fastanime/libs/anime_provider/base_provider.py
+++ b/fastanime/libs/anime_provider/base_provider.py
@@ -18,8 +18,7 @@ class AnimeProvider:
             from ..common.requests_cacher import CachedRequestsSession
 
             self.session = CachedRequestsSession(
-                os.path.join(APP_CACHE_DIR, "anime_provider_cached_requests.db"),
-                os.path.join(APP_CACHE_DIR, "anime_provider_cached_requests.lock"),
+                os.path.join(APP_CACHE_DIR, "cached_requests.db")
             )
         else:
             self.session = requests.session()
diff --git a/fastanime/libs/common/requests_cacher.py b/fastanime/libs/common/requests_cacher.py
index bcd871e..e4e369e 100644
--- a/fastanime/libs/common/requests_cacher.py
+++ b/fastanime/libs/common/requests_cacher.py
@@ -1,16 +1,14 @@
-import atexit
 import json
 import logging
-import os
-import pathlib
 import re
-import sqlite3
 import time
 from datetime import datetime
 from urllib.parse import urlencode
 
 import requests
 
+from .sqlitedb_helper import SqliteDB
+
 logger = logging.getLogger(__name__)
 
 caching_mimetypes = {
@@ -51,7 +49,6 @@ class CachedRequestsSession(requests.Session):
     def __init__(
         self,
         cache_db_path: str,
-        cache_db_lock_file: str,
         max_lifetime: int = 604800,
         max_size: int = (1024**2) * 10,
         table_name: str = "fastanime_requests_cache",
@@ -61,37 +58,30 @@ class CachedRequestsSession(requests.Session):
     ):
         super().__init__(*args, **kwargs)
 
-        self.clean_db = clean_db
-        self.lockfile_path = pathlib.Path(cache_db_lock_file)
         self.cache_db_path = cache_db_path
         self.max_lifetime = max_lifetime
         self.max_size = max_size
         self.table_name = table_name
 
-        logger.debug("Acquiring lock on the db")
-        self.acquirer_lock(cache_db_lock_file)
-        logger.debug("Successfully acquired lock on the db")
-        logger.debug("Getting connection to cache db")
-        self.connection = sqlite3.connect(self.cache_db_path)
-        logger.debug("Successfully gotten connection to cache db")
+        self.sqlite_db_connection = SqliteDB(self.cache_db_path)
 
-        logger.debug("Creating table if it does not exist")
-        self.connection.cursor().execute(
-            f"""
-            create table if not exists {self.table_name!r} (
-                url text,
-                status_code integer,
-                request_headers text,
-                response_headers text,
-                data blob,
-                redirection_policy int,
-                cache_expiry integer
-            )""",
-        )
+        # Prepare the cache table if it doesn't exist
+        self._create_cache_table()
 
-        atexit.register(
-            self._kill_connection_to_db,
-            self.connection,
-        )
+    def _create_cache_table(self):
+        """Create cache table if it doesn't exist."""
+        with self.sqlite_db_connection as conn:
+            conn.execute(
+                f"""
+                CREATE TABLE IF NOT EXISTS {self.table_name} (
+                    url TEXT,
+                    status_code INTEGER,
+                    request_headers TEXT,
+                    response_headers TEXT,
+                    data BLOB,
+                    redirection_policy INT,
+                    cache_expiry INTEGER
+                )"""
+            )
 
     def request(
         self,
@@ -103,97 +93,95 @@ class CachedRequestsSession(requests.Session):
         *args,
         **kwargs,
     ):
-        # do a new request without caching
         if fresh:
             logger.debug("Executing fresh request")
             return super().request(method, url, params=params, *args, **kwargs)
 
-        # construct the exact url if it has params
-        if params is not None:
+        if params:
             url += "?" + urlencode(params)
 
         redirection_policy = int(kwargs.get("force_redirects", False))
 
-        # fetch cached request from database
-        time_before_access_db = datetime.now()
-        cursor = self.connection.cursor()
-        logger.debug("Checking for existing request in cache db")
-        cursor.execute(
-            f"""
-            select
-                status_code,
-                request_headers,
-                response_headers,
-                data,
-                redirection_policy
-            from {self.table_name!r}
-            where
-                url = ?
-                and redirection_policy = ?
-                and cache_expiry > ?
-            """,
-            (url, redirection_policy, int(time.time())),
-        )
-        cached_request = cursor.fetchone()
-        time_after_access_db = datetime.now()
+        with self.sqlite_db_connection as conn:
+            cursor = conn.cursor()
+            time_before_access_db = datetime.now()
 
-        # construct response from cached request
-        if cached_request is not None:
-            logger.debug("Found existing request in cache db")
-            (
-                status_code,
-                request_headers,
-                response_headers,
-                data,
-                redirection_policy,
-            ) = cached_request
-
-            response = requests.Response()
-            response.headers.update(json.loads(response_headers))
-            response.status_code = status_code
-            response._content = data
-
-            if "timeout" in kwargs:
-                kwargs.pop("timeout")
-            _request = requests.Request(
-                method, url, headers=json.loads(request_headers), *args, **kwargs
-            )
-            response.request = _request.prepare()
-            response.elapsed = time_after_access_db - time_before_access_db
-            logger.debug(f"Cacher Elapsed: {response.elapsed}")
-            return response
-
-        # construct a new response if the request does not already exist in the cache
-        # cache the response provided conditions to cache are met
-        response = super().request(method, url, *args, **kwargs)
-        if response.ok and (
-            force_caching
-            or self.is_content_type_cachable(
-                response.headers.get("content-type"), caching_mimetypes
-            )
-            and len(response.content) < self.max_size
-        ):
-            logger.debug("Caching current request")
-            cursor.execute(
-                f"insert into {self.table_name!r} values (?, ?, ?, ?, ?, ?, ?)",
-                (
-                    url,
-                    response.status_code,
-                    json.dumps(dict(response.request.headers)),
-                    json.dumps(dict(response.headers)),
-                    response.content,
-                    redirection_policy,
-                    int(time.time()) + self.max_lifetime,
-                ),
-            )
+            logger.debug("Checking for existing request in cache")
+            cursor.execute(
+                f"""
+                SELECT
+                    status_code,
+                    request_headers,
+                    response_headers,
+                    data,
+                    redirection_policy
+                FROM {self.table_name}
+                WHERE
+                    url = ?
+                    AND redirection_policy = ?
+                    AND cache_expiry > ?
+                """,
+                (url, redirection_policy, int(time.time())),
+            )
+            cached_request = cursor.fetchone()
+            time_after_access_db = datetime.now()
 
-            self.connection.commit()
+            if cached_request:
+                logger.debug("Found existing request in cache")
+                (
+                    status_code,
+                    request_headers,
+                    response_headers,
+                    data,
+                    redirection_policy,
+                ) = cached_request
 
-        return response
+                response = requests.Response()
+                response.headers.update(json.loads(response_headers))
+                response.status_code = status_code
+                response._content = data
+
+                if "timeout" in kwargs:
+                    kwargs.pop("timeout")
+                _request = requests.Request(
+                    method, url, headers=json.loads(request_headers), *args, **kwargs
+                )
+                response.request = _request.prepare()
+                response.elapsed = time_after_access_db - time_before_access_db
+
+                return response
+
+            # Perform the request and cache it
+            response = super().request(method, url, *args, **kwargs)
+            if response.ok and (
+                force_caching
+                or self.is_content_type_cachable(
+                    response.headers.get("content-type"), caching_mimetypes
+                )
+                and len(response.content) < self.max_size
+            ):
+                logger.debug("Caching the current request")
+                cursor.execute(
+                    f"""
+                    INSERT INTO {self.table_name}
+                    VALUES (?, ?, ?, ?, ?, ?, ?)
+                    """,
+                    (
+                        url,
+                        response.status_code,
+                        json.dumps(dict(response.request.headers)),
+                        json.dumps(dict(response.headers)),
+                        response.content,
+                        redirection_policy,
+                        int(time.time()) + self.max_lifetime,
+                    ),
+                )
+
+            return response
 
     @staticmethod
     def is_content_type_cachable(content_type, caching_mimetypes):
-        """Checks whether the given encoding is supported by the cacher"""
+        """Checks whether the given encoding is supported by the cacher"""
         if content_type is None:
             return True
 
@@ -205,40 +193,9 @@ class CachedRequestsSession(requests.Session):
             content in caching_mimetypes[mime] for content in contents.split("+")
         )
 
-    def kill_connection_to_db(self):
-        self._kill_connection_to_db(self.connection)
-        self.lockfile_path.unlink()
-        atexit.unregister(self.lockfile_path.unlink)
-        atexit.unregister(self._kill_connection_to_db)
-
-    def _kill_connection_to_db(self, connection):
-        connection.commit()
-        connection.close()
-        if self.clean_db:
-            os.remove(self.cache_db_path)
-
-    def acquirer_lock(self, lock_file: str):
-        """the function creates a lock file preventing other instances of the cacher from running at the same time"""
-
-        if self.lockfile_path.exists():
-            with self.lockfile_path.open("r") as f:
-                pid = f.read()
-
-            raise RuntimeError(
-                f"This instance of {__class__.__name__!r} is already running with PID: {pid}. "
-                "Sqlite3 does not support multiple connections to the same database. "
-                "If you are sure that no other instance of this class is running, "
-                f"delete the lock file at {self.lockfile_path.as_posix()!r} and try again."
-            )
-
-        with self.lockfile_path.open("w") as f:
-            f.write(str(os.getpid()))
-
-        atexit.register(self.lockfile_path.unlink)
-
 
 if __name__ == "__main__":
-    with CachedRequestsSession("cache.db", "cache.lockfile") as session:
+    with CachedRequestsSession("cache.db") as session:
         response = session.get(
             "https://google.com",
         )
diff --git a/fastanime/libs/common/sqlitedb_helper.py b/fastanime/libs/common/sqlitedb_helper.py
new file mode 100644
index 0000000..7549b94
--- /dev/null
+++ b/fastanime/libs/common/sqlitedb_helper.py
@@ -0,0 +1,34 @@
+import logging
+import sqlite3
+import time
+
+logger = logging.getLogger(__name__)
+
+
+class SqliteDB:
+    def __init__(self, db_path: str) -> None:
+        self.db_path = db_path
+        self.connection = sqlite3.connect(self.db_path)
+        logger.debug("Enabling WAL mode for concurrent access")
+        self.connection.execute("PRAGMA journal_mode=WAL;")
+        self.connection.close()
+        self.connection = None
+
+    def __enter__(self):
+        logger.debug("Starting new connection...")
+        start_time = time.time()
+        self.connection = sqlite3.connect(self.db_path)
+        logger.debug(
+            "Successfully got a new connection in {} seconds".format(
+                time.time() - start_time
+            )
+        )
+        return self.connection
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.connection:
+            logger.debug("Closing connection to cache db")
+            self.connection.commit()
+            self.connection.close()
+            self.connection = None
+            logger.debug("Successfully closed connection to cache db")
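
Note on the design change above: the old implementation guarded one long-lived sqlite3 connection with a PID lockfile, so a second fastanime process could not use the cache at all. The new SqliteDB helper instead enables WAL journaling once (the pragma is persisted in the database file itself) and opens a short-lived connection per `with` block, committing in `__exit__`. WAL lets readers proceed while a writer is active; writes still serialize, but the per-transaction connections keep lock windows short. A minimal sketch of the intended usage, not part of the diff (the file name demo.db, the kv table, and the writer/reader names are made up for illustration; the SqliteDB API is as defined above):

    from fastanime.libs.common.sqlitedb_helper import SqliteDB

    # Two independent handles to the same database file, as two
    # processes would have. WAL mode (set once in SqliteDB.__init__)
    # lets them coexist without a lockfile.
    writer = SqliteDB("demo.db")
    reader = SqliteDB("demo.db")

    with writer as conn:
        conn.execute("CREATE TABLE IF NOT EXISTS kv (k TEXT, v TEXT)")
        conn.execute("INSERT INTO kv VALUES (?, ?)", ("greeting", "hello"))
        # the commit happens in SqliteDB.__exit__

    with reader as conn:
        # A separately opened connection sees the committed row.
        print(conn.execute("SELECT v FROM kv WHERE k = ?", ("greeting",)).fetchone())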