fix(cache): handle get_from_cache=None and ensure directory exists #544

Merged (1 commit) on Jul 4, 2024
safety/safety.py (123 changes: 67 additions & 56 deletions)
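In short, the diff below wraps both the cache read (get_from_cache) and the cache write (write_to_cache) in a filelock.FileLock held on a sibling "<cache file>.lock" file, creates the cache directory up front with os.makedirs(..., exist_ok=True), and makes get_from_cache return None instead of False when nothing usable is cached. A minimal sketch of that pattern follows; CACHE_FILE, read_cache and max_age_seconds are illustrative names, not Safety's own DB_CACHE_FILE or API.

import json
import os
import time
from typing import Any, Dict, Optional

from filelock import FileLock  # third-party package the PR starts importing

CACHE_FILE = "/tmp/example-cache/db.json"  # illustrative path, stands in for DB_CACHE_FILE


def read_cache(key: str, max_age_seconds: int = 3600) -> Optional[Dict[str, Any]]:
    lock_path = f"{CACHE_FILE}.lock"
    os.makedirs(os.path.dirname(lock_path), exist_ok=True)  # ensure the cache directory exists
    with FileLock(lock_path, timeout=10):                    # serialize concurrent readers/writers
        if not os.path.exists(CACHE_FILE):
            return None                                      # cache miss is None, not False
        try:
            with open(CACHE_FILE) as f:
                entry = json.loads(f.read()).get(key)
        except json.JSONDecodeError:
            return None                                      # a corrupt cache behaves like a miss
    if not entry or entry.get("cached_at", 0) + max_age_seconds < time.time():
        return None                                          # an expired entry behaves like a miss
    return entry.get("db")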
@@ -12,7 +12,7 @@
import time
from collections import defaultdict
from datetime import datetime
-from typing import Dict, Optional, List
+from typing import Dict, Optional, List, Any

import click
import requests
@@ -21,6 +21,7 @@
from packaging.utils import canonicalize_name
from packaging.version import parse as parse_version, Version
from pydantic.json import pydantic_encoder
+from filelock import FileLock

from safety_schemas.models import Ecosystem, FileType

@@ -41,34 +42,38 @@
LOG = logging.getLogger(__name__)


-def get_from_cache(db_name, cache_valid_seconds=0, skip_time_verification=False):
-    if os.path.exists(DB_CACHE_FILE):
-        with open(DB_CACHE_FILE) as f:
-            try:
-                data = json.loads(f.read())
-                if db_name in data:
-
-                    if "cached_at" in data[db_name]:
-                        if data[db_name]["cached_at"] + cache_valid_seconds > time.time() or skip_time_verification:
-                            LOG.debug('Getting the database from cache at %s, cache setting: %s',
-                                      data[db_name]["cached_at"], cache_valid_seconds)
-
-                            try:
-                                data[db_name]["db"]["meta"]["base_domain"] = "https://data.safetycli.com"
-                            except KeyError as e:
-                                pass
-
-                            return data[db_name]["db"]
-
-                        LOG.debug('Cached file is too old, it was cached at %s', data[db_name]["cached_at"])
-                    else:
-                        LOG.debug('There is not the cached_at key in %s database', data[db_name])
-
-            except json.JSONDecodeError:
-                LOG.debug('JSONDecodeError trying to get the cached database.')
-    else:
-        LOG.debug("Cache file doesn't exist...")
-    return False
+def get_from_cache(db_name: str, cache_valid_seconds: int = 0, skip_time_verification: bool = False) -> Optional[Dict[str, Any]]:
+    cache_file_lock = f"{DB_CACHE_FILE}.lock"
+    os.makedirs(os.path.dirname(cache_file_lock), exist_ok=True)
+    lock = FileLock(cache_file_lock, timeout=10)
+    with lock:
+        if os.path.exists(DB_CACHE_FILE):
+            with open(DB_CACHE_FILE) as f:
+                try:
+                    data = json.loads(f.read())
+                    if db_name in data:
+
+                        if "cached_at" in data[db_name]:
+                            if data[db_name]["cached_at"] + cache_valid_seconds > time.time() or skip_time_verification:
+                                LOG.debug('Getting the database from cache at %s, cache setting: %s',
+                                          data[db_name]["cached_at"], cache_valid_seconds)
+
+                                try:
+                                    data[db_name]["db"]["meta"]["base_domain"] = "https://data.safetycli.com"
+                                except KeyError as e:
+                                    pass
+
+                                return data[db_name]["db"]
+
+                            LOG.debug('Cached file is too old, it was cached at %s', data[db_name]["cached_at"])
+                        else:
+                            LOG.debug('There is not the cached_at key in %s database', data[db_name])
+
+                except json.JSONDecodeError:
+                    LOG.debug('JSONDecodeError trying to get the cached database.')
+        else:
+            LOG.debug("Cache file doesn't exist...")
+    return None


def write_to_cache(db_name, data):
@@ -95,25 +100,31 @@ def write_to_cache(db_name, data):
            if exc.errno != errno.EEXIST:
                raise

with open(DB_CACHE_FILE, "r") as f:
try:
cache = json.loads(f.read())
except json.JSONDecodeError:
LOG.debug('JSONDecodeError in the local cache, dumping the full cache file.')
cache_file_lock = f"{DB_CACHE_FILE}.lock"
lock = FileLock(cache_file_lock, timeout=10)
with lock:
if os.path.exists(DB_CACHE_FILE):
with open(DB_CACHE_FILE, "r") as f:
try:
cache = json.loads(f.read())
except json.JSONDecodeError:
LOG.debug('JSONDecodeError in the local cache, dumping the full cache file.')
cache = {}
else:
cache = {}

with open(DB_CACHE_FILE, "w") as f:
cache[db_name] = {
"cached_at": time.time(),
"db": data
}
f.write(json.dumps(cache))
LOG.debug('Safety updated the cache file for %s database.', db_name)
with open(DB_CACHE_FILE, "w") as f:
cache[db_name] = {
"cached_at": time.time(),
"db": data
}
f.write(json.dumps(cache))
LOG.debug('Safety updated the cache file for %s database.', db_name)


def fetch_database_url(session, mirror, db_name, cached, telemetry=True,
                       ecosystem: Ecosystem = Ecosystem.PYTHON, from_cache=True):
    headers = {'schema-version': JSON_SCHEMA_VERSION, 'ecosystem': ecosystem.value}

    if cached and from_cache:
        cached_data = get_from_cache(db_name=db_name, cache_valid_seconds=cached)
@@ -122,13 +133,13 @@
            return cached_data
    url = mirror + db_name


    telemetry_data = {
        'telemetry': json.dumps(build_telemetry_data(telemetry=telemetry),
                                default=pydantic_encoder)}

    try:
        r = session.get(url=url, timeout=REQUEST_TIMEOUT,
                        headers=headers, params=telemetry_data)
    except requests.exceptions.ConnectionError:
        raise NetworkConnectionError()
@@ -205,10 +216,10 @@ def fetch_database_file(path: str, db_name: str, cached = 0,

    if not full_path.exists():
        raise DatabaseFileNotFoundError(db=path)

    with open(full_path) as f:
        data = json.loads(f.read())

    if cached:
        LOG.info('Writing %s to cache because cached value was %s', db_name, cached)
        write_to_cache(db_name, data)
@@ -226,7 +237,7 @@ def is_valid_database(db) -> bool:
    return False


def fetch_database(session, full=False, db=False, cached=0, telemetry=True,
                   ecosystem: Optional[Ecosystem] = None, from_cache=True):

    if session.is_using_auth_credentials():
@@ -242,7 +253,7 @@ def fetch_database(session, full=False, db=False, cached=0, telemetry=True,
        if is_a_remote_mirror(mirror):
            if ecosystem is None:
                ecosystem = Ecosystem.PYTHON
            data = fetch_database_url(session, mirror, db_name=db_name, cached=cached,
                                      telemetry=telemetry, ecosystem=ecosystem, from_cache=from_cache)
        else:
            data = fetch_database_file(mirror, db_name=db_name, cached=cached,
@@ -562,16 +573,16 @@ def compute_sec_ver(remediations, packages: Dict[str, Package], secure_vulns_by_
                secure_v = compute_sec_ver_for_user(package=pkg, secure_vulns_by_user=secure_vulns_by_user, db_full=db_full)

            rem['closest_secure_version'] = get_closest_ver(secure_v, version, spec)

            upgrade = rem['closest_secure_version'].get('upper', None)
            downgrade = rem['closest_secure_version'].get('lower', None)
            recommended_version = None

            if upgrade:
                recommended_version = upgrade
            elif downgrade:
                recommended_version = downgrade

            rem['recommended_version'] = recommended_version
            rem['other_recommended_versions'] = [other_v for other_v in secure_v if
                                                 other_v != str(recommended_version)]
@@ -645,12 +656,12 @@ def process_fixes(files, remediations, auto_remediation_limit, output, no_output

def process_fixes_scan(file_to_fix, to_fix_spec, auto_remediation_limit, output, no_output=True, prompt=False):
    to_fix_remediations = []

    def get_remmediation_from(spec):
        upper = None
        lower = None
        recommended = None

        try:
            upper = Version(spec.remediation.closest_secure.upper) if spec.remediation.closest_secure.upper else None
        except Exception as e:
@@ -664,15 +675,15 @@ def get_remmediation_from(spec):
        try:
            recommended = Version(spec.remediation.recommended)
        except Exception as e:
            LOG.error(f'Error getting recommended version for remediation, ignoring', exc_info=True)

        return {
            "vulnerabilities_found": spec.remediation.vulnerabilities_found,
            "version": next(iter(spec.specifier)).version if spec.is_pinned() else None,
            "requirement": spec,
            "more_info_url": spec.remediation.more_info_url,
            "closest_secure_version": {
                'upper': upper,
                'lower': lower
            },
            "recommended_version": recommended,
@@ -690,7 +701,7 @@ def get_remmediation_from(spec):
        'files': {str(file_to_fix.location): {'content': None, 'fixes': {'TO_SKIP': [], 'TO_APPLY': [], 'TO_CONFIRM': []}, 'supported': False, 'filename': file_to_fix.location.name}},
        'dependencies': defaultdict(dict),
    }

    fixes = apply_fixes(requirements, output, no_output, prompt, scan_flow=True, auto_remediation_limit=auto_remediation_limit)

    return fixes
@@ -822,7 +833,7 @@ def apply_fixes(requirements, out_type, no_output, prompt, scan_flow=False, auto
    for name, data in requirements['files'].items():
        output = [('', {}),
                  (f"Analyzing {name}... [{get_fix_opt_used_msg(auto_remediation_limit)} limit]", {'styling': {'bold': True}, 'start_line_decorator': '->', 'indent': ' '})]

        r_skip = data['fixes']['TO_SKIP']
        r_apply = data['fixes']['TO_APPLY']
        r_confirm = data['fixes']['TO_CONFIRM']
@@ -901,7 +912,7 @@ def apply_fixes(requirements, out_type, no_output, prompt, scan_flow=False, auto
        else:
            not_supported_filename = data.get('filename', name)
            output.append(
                (f"{not_supported_filename} updates not supported: Please update these dependencies using your package manager.",
                 {'start_line_decorator': ' -', 'indent': ' '}))
            output.append(('', {}))

@@ -999,7 +1010,7 @@ def review(*, report=None, params=None):

@sync_safety_context
def get_licenses(*, session=None, db_mirror=False, cached=0, telemetry=True):

    if db_mirror:
        mirrors = [db_mirror]
    else:
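For completeness, the locked write path in the diff follows the same shape: take the same .lock file, tolerate a missing or corrupt cache file, and rewrite the whole JSON blob while the lock is held. A sketch under the same illustrative assumptions (CACHE_FILE stands in for DB_CACHE_FILE; write_cache is not a Safety API):

import json
import os
import time

from filelock import FileLock

CACHE_FILE = "/tmp/example-cache/db.json"  # illustrative path, stands in for DB_CACHE_FILE


def write_cache(key: str, payload: dict) -> None:
    os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)  # mirror the "ensure directory exists" fix
    with FileLock(f"{CACHE_FILE}.lock", timeout=10):         # same lock file the reader takes
        cache = {}
        if os.path.exists(CACHE_FILE):
            try:
                with open(CACHE_FILE) as f:
                    cache = json.loads(f.read())
            except json.JSONDecodeError:
                cache = {}                                    # corrupt cache: start from scratch
        cache[key] = {"cached_at": time.time(), "db": payload}
        with open(CACHE_FILE, "w") as f:
            f.write(json.dumps(cache))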