#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""This module contains all things related to download monitoring and download queues in pygeodes
"""
# -----------------------------------------------------------------------------
# Copyright (c) 2024, CNES
#
# REFERENCES:
# https://cnes.fr/
# -----------------------------------------------------------------------------
# stdlib imports -------------------------------------------------------
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
import os
from typing import Literal, List
from time import sleep
from uuid import uuid4
from pathlib import Path
import warnings
# third-party imports -----------------------------------------------
from rich.live import Live
from rich.table import Table
from tqdm import tqdm
# local imports ---------------------------------------------------
from pygeodes.utils.io import (
load_json,
file_exists,
write_json,
compute_md5,
)
from pygeodes.utils.consts import PROFILE_DEFAULT_PATH
from pygeodes.utils.dataclasses_utils import class_from_args
from pygeodes.utils.stac import Item
from pygeodes.utils.datetime_utils import (
datetime_to_str,
str_to_datetime,
a_week_ago,
today,
time_ago,
)
from pygeodes.utils.logger import logger
from pygeodes.utils.download import correct_download_tld
from pygeodes.utils.s3 import download_item as download_item_from_s3
@dataclass
class Download:
    """This class represents a pending, current or finished Geodes download.

    The start and completion timestamps are kept as strings in ``_started_at``
    and ``_completed_at`` (converted with ``datetime_to_str``) and exposed as
    datetime objects through the ``started_at`` / ``completed_at`` properties.
    Two downloads are considered equal when they share the same ``_id``.
    """

    url: str
    destination: str = None
    _started_at: str = None
    _completed_at: str = None
    _id: str = None

    def __post_init__(self):
        """This function is executed after instantiating a Download object and normalises its fields"""
        # a download created without an explicit id gets a fresh random one
        if self._id is None:
            self._id = str(uuid4())
        # the destination may be given as a path-like object; store it as a string
        if self.destination:
            self.destination = str(self.destination)
        # timestamps are stored as strings; isinstance also accepts datetime subclasses
        if isinstance(self._started_at, datetime):
            self._started_at = datetime_to_str(self._started_at)
        if isinstance(self._completed_at, datetime):
            self._completed_at = datetime_to_str(self._completed_at)

    def start(self):
        """This method starts the download (sets _started_at to now)

        Does nothing when the download was already started, so the original
        start time is kept.
        """
        if self._started_at is None:
            self._started_at = datetime_to_str(datetime.now())

    @classmethod
    def from_dict(cls, dico: dict):
        """This function creates a Download object from a dictionary

        Parameters
        ----------
        dico : dict
            a dictionary representing a Download object

        Returns
        -------
        Download
            the object built by ``class_from_args`` from ``dico``
        """
        return class_from_args(cls, dico)

    @property
    def started_at(self):
        """This returns _started_at as a datetime object

        Returns
        -------
        datetime or None
            started_at, or None if the download was not started
        """
        if self._started_at is None:
            return None
        return str_to_datetime(self._started_at)

    @property
    def completed_at(self):
        """This returns _completed_at as a datetime object

        Returns
        -------
        datetime or None
            completed_at, or None if the download is not completed
        """
        if self._completed_at is None:
            return None
        return str_to_datetime(self._completed_at)

    @property
    def time_taken(self):
        """This returns the time taken to download if the download is completed

        Returns
        -------
        timedelta or None
            the time between start and completion, or None when either
            timestamp is missing
        """
        completed = self.completed_at
        started = self.started_at
        # complete() can be called without start(); avoid subtracting None
        if not completed or started is None:
            return None
        return completed - started

    def to_dict(self):
        """This returns the download as a plain dictionary (one key per field)

        Returns
        -------
        dict
            the fields of the download; ``destination`` is a string, or None
            when no destination is set
        """
        dico = asdict(self)
        # keep a missing destination as None instead of the string "None"
        if dico["destination"] is not None:
            dico["destination"] = str(dico["destination"])
        return dico

    def complete(self):
        """This sets the _completed_at attr to now"""
        self._completed_at = datetime_to_str(datetime.now())

    def __eq__(self, obj):
        """Two downloads are equal when their ``_id`` match"""
        try:
            return obj._id == self._id
        except AttributeError:
            # not a download-like object: let Python fall back to its default
            return NotImplemented

    def __hash__(self):
        """Hash on ``_id`` so that it stays consistent with ``__eq__``"""
        return hash(self._id)
class Profile:
    """Collection of Download objects, keyed by their ``_id``.

    The profile is read from and written to the file at ``_filepath`` with
    ``load_json`` / ``write_json``.
    """

    _filepath = PROFILE_DEFAULT_PATH

    def __init__(self):
        """Create an empty profile (no downloads)"""
        self.downloads = {}

    def get_download_from_uuid(self, uuid):
        """Return the download registered under ``uuid``, or None if there is none"""
        return self.downloads.get(uuid)

    @classmethod
    def reset(cls):
        """Delete the profile file if it exists"""
        if file_exists(cls._filepath, False):
            os.remove(cls._filepath)

    @classmethod
    def load(cls):
        """Load the profile from ``_filepath``

        If the file does not exist yet, an empty profile is written first.

        Returns
        -------
        Profile
            a profile holding one Download per entry of the file
        """
        if not file_exists(cls._filepath, False):
            write_json({"downloads": {}}, cls._filepath)
        content = load_json(cls._filepath)
        obj = cls()
        # tolerate a file whose "downloads" key is missing or null
        stored = content.get("downloads") or {}
        obj.downloads = {
            key: Download.from_dict(value) for key, value in stored.items()
        }
        return obj

    def add_download(self, download: Download):
        """Register ``download`` under its ``_id`` (replacing any previous entry)"""
        self.downloads[download._id] = download

    @property
    def nb_downloads(self):
        """Number of downloads registered in the profile"""
        return len(self.downloads)

    @classmethod
    def watch_downloads(cls, refresh_rate: float = 5, simplified: bool = False):
        """Display the state of the downloads until interrupted with Ctrl-C

        The profile is reloaded from disk at every iteration.

        Parameters
        ----------
        refresh_rate : float
            number of seconds slept between two refreshes
        simplified : bool
            if True, print one line with the number of current and pending
            downloads; otherwise display a live table of the just finished,
            current and pending downloads
        """
        try:
            if simplified:
                while True:
                    profile = cls.load()
                    current_downloads = profile.current_downloads()
                    pending_downloads = profile.pending_downloads()
                    print(
                        f"[{datetime.now()}] Current downloads : {len(current_downloads)} | Pending downloads : {len(pending_downloads)}"
                    )
                    sleep(refresh_rate)
            else:
                started = datetime.now()

                def create_table():
                    table = Table(
                        caption=f"[bold orange1]Current[/bold orange1], [bold red]pending[/bold red] and [bold green]just finished downloads[/bold green] (refreshing every {refresh_rate} seconds, started watching {time_ago(started)})",
                        expand=True,
                        show_lines=True,
                    )
                    table.add_column("Download ID")
                    table.add_column("URL")
                    table.add_column("Destination")
                    table.add_column("Started at")
                    table.add_column("Completed at")
                    return table

                def add_rows(table, downloads, color, show_started, show_completed):
                    # one row per download, every cell prefixed with the color tag;
                    # timestamps that do not apply to this state are left empty
                    for download in downloads:
                        started_cell = ""
                        completed_cell = ""
                        if show_started:
                            started_cell = f"{str(download.started_at)} ({time_ago(download.started_at)})"
                        if show_completed:
                            completed_cell = f"{str(download.completed_at)} ({time_ago(download.completed_at)})"
                        cells = [
                            download._id,
                            download.url,
                            download.destination,
                            started_cell,
                            completed_cell,
                        ]
                        table.add_row(*[f"[{color}]{cell}" for cell in cells])

                # NOTE(review): refresh_rate is a period in seconds (see sleep
                # below) but is passed as a frequency here — confirm intent
                with Live(
                    create_table(), refresh_per_second=refresh_rate
                ) as live:
                    while True:
                        table = create_table()
                        profile = cls.load()
                        pending_downloads = profile.pending_downloads()
                        current_downloads = profile.current_downloads()
                        just_finished = profile.just_finished_downloads(
                            nb_seconds=refresh_rate * 2
                        )  # times two to be sure we get to see the finished downloads
                        add_rows(table, just_finished, "green", True, True)
                        add_rows(table, current_downloads, "orange1", True, False)
                        add_rows(table, pending_downloads, "red", False, False)
                        live.update(table)
                        sleep(refresh_rate)
        except KeyboardInterrupt:
            print("Download watching interrupted")

    def current_downloads(self):
        """Return the set of downloads that are started but not completed"""
        return {
            download
            for download in self.downloads.values()
            if download.completed_at is None and download.started_at is not None
        }

    def pending_downloads(self):
        """Return the set of downloads that are not started yet"""
        return {
            download
            for download in self.downloads.values()
            if download.started_at is None
        }

    def just_finished_downloads(self, nb_seconds):
        """Return the set of downloads completed during the last ``nb_seconds`` seconds

        Parameters
        ----------
        nb_seconds : float
            size of the time window in seconds; values below 1 are raised to 1
        """
        if nb_seconds < 1:
            nb_seconds = 1
        end = datetime.now()
        begin = end - timedelta(seconds=nb_seconds)
        recent = set()
        for download in self.downloads.values():
            completed = download.completed_at
            if completed is not None and begin <= completed <= end:
                recent.add(download)
        return recent

    def recent_downloads(self, since: Literal["today", "this_week"] = "today"):
        """Return the downloads started or completed in the given period

        Downloads that are not started yet are never part of the result.

        Parameters
        ----------
        since : Literal["today", "this_week"]
            the period to look at

        Returns
        -------
        list
            the downloads whose start or completion falls in the period

        Raises
        ------
        ValueError
            if ``since`` is not one of the supported values
        """
        if since == "today":
            begin, end = today()
        elif since == "this_week":
            begin = a_week_ago()
            end = datetime.now()
        else:
            raise ValueError(
                "Please use one of 'today','this_week' as 'since' parameter"
            )

        def is_in_interval(obj):
            # a missing timestamp (pending or unfinished download) cannot be
            # compared with a datetime, so it simply does not match
            moments = (obj.completed_at, obj.started_at)
            return any(
                moment is not None and begin <= moment <= end
                for moment in moments
            )

        print(f"Displaying downloads from {begin} to {end}")
        return [
            download
            for download in self.downloads.values()
            if is_in_interval(download)
        ]

    def to_dict(self):
        """Return the profile as a plain dictionary

        Returns
        -------
        dict
            ``nb_downloads`` and ``downloads`` (each download as a dictionary,
            keyed like ``self.downloads``)
        """
        return {
            "nb_downloads": self.nb_downloads,
            "downloads": {
                key: value.to_dict() for key, value in self.downloads.items()
            },
        }

    def save(self):
        """Write the profile to ``_filepath``"""
        payload = self.to_dict()  # built once, used for both the log and the file
        logger.debug(f"saving profile as {payload}")
        write_json(payload, self._filepath)
def load_profile_and_save_download(d: Download):
    """Register a download in the profile file.

    The profile is loaded with ``Profile.load``, ``d`` is added to it and the
    profile is saved again.

    Parameters
    ----------
    d : Download
        the download to register
    """
    stored_profile = Profile.load()
    stored_profile.add_download(d)
    stored_profile.save()
class DownloadQueue:
    """Queue downloading a list of items one after the other.

    Every item gets a Download object (kept in ``downloads_objects``) which is
    also registered in the Profile, so the progress of the queue can be
    followed from the profile file.
    """

    def __init__(self, items: List[Item], download_dir: str = None):
        """Create a queue for ``items``

        Parameters
        ----------
        items : List[Item]
            the items to download
        download_dir : str, optional
            the directory to download into; when not provided, the
            ``download_dir`` of the configuration of the last Geodes instance
            is used
        """
        self.items = items
        self.downloads_objects = {}
        # imported locally rather than at module level — presumably to avoid a
        # circular import with pygeodes.geodes; TODO confirm
        from pygeodes.geodes import Geodes

        self.geodes_instance = Geodes.get_last_instance()
        if download_dir:
            self.download_dir = download_dir
        else:
            self.download_dir = self.geodes_instance.conf.download_dir

    def _download_item(self, item: Item):
        """Download one item and record the start and the completion in the profile

        The S3 client of the Geodes instance is used when there is one,
        otherwise the file is fetched through the request maker.
        """
        download_for_profile = self.downloads_objects[item]
        outfile = str(Path(self.download_dir).joinpath(item.data_asset.title))
        download_for_profile.destination = outfile
        if self.geodes_instance.s3_client is not None:
            download_for_profile.url = item.find("accessService:endpointURL")
            download_for_profile.start()
            load_profile_and_save_download(download_for_profile)
            outfile_really_used = download_item_from_s3(
                self.geodes_instance.s3_client, item, outfile=outfile
            )
        else:
            download_url = correct_download_tld(
                item.data_asset.href
            )  # temp as top level domains aren't ok
            download_for_profile.url = download_url
            download_for_profile.start()
            load_profile_and_save_download(download_for_profile)
            outfile_really_used = self.geodes_instance.request_maker.download_file(  # because outfile may change if already used or sth like that
                download_url,
                outfile,
                checksum=item.data_asset_checksum,
                checksum_error=self.geodes_instance.conf.checksum_error,
                verbose=False,
            )
        # the profile is reloaded because the file may have changed while the
        # transfer was running
        profile = Profile.load()
        stored = profile.get_download_from_uuid(download_for_profile._id)
        if stored is None:
            # the entry is gone (e.g. the profile was reset during the
            # transfer): register the in-memory record again instead of failing
            stored = download_for_profile
            profile.add_download(stored)
        stored.destination = outfile_really_used
        self.downloads_objects[item] = stored
        stored.complete()
        profile.save()

    def _init_downloads(self):
        """Create and register one pending Download per item (only on first call)"""
        if self.downloads_objects:
            return
        profile = Profile.load()
        for item in self.items:
            d = Download(
                url=item.data_asset.href, destination=None, _started_at=None
            )
            self.downloads_objects[item] = d
            profile.add_download(d)
        profile.save()

    def check_integrity(self):
        """Warn about every item whose download is missing, incomplete or corrupted

        For each completed download the MD5 of the file at destination is
        compared with the checksum of the item. Nothing is raised or returned,
        problems are reported with ``warnings.warn``.
        """
        for item, download in self.downloads_objects.items():
            if download.completed_at is None:
                warnings.warn(f"Download for item {item} was not completed")
                continue
            # NOTE(review): _download_item reads item.data_asset_checksum while
            # this reads item.data_asset.checksum — confirm both exist
            checksum = item.data_asset.checksum
            if not file_exists(download.destination, False):
                warnings.warn(
                    f"File is absent at destination for item {item}"
                )
                continue
            checksum_at_destination = compute_md5(download.destination)
            if checksum != checksum_at_destination:
                warnings.warn(
                    f"File for item {item} doesn't correspond to checksum"
                )

    def run(self):
        """Download every item of the queue that is not completed yet

        A download that was started but never completed is started again from
        scratch: the partial file at its destination is removed first.
        """
        self._init_downloads()
        downloads = self.downloads_objects.values()
        if all(download.completed_at is not None for download in downloads):
            print(f"All downloads are completed in {self.download_dir}")
            return
        completed = [
            download
            for download in downloads
            if download.completed_at is not None
        ]
        pending = [
            download
            for download in downloads
            if download.completed_at is None and download.started_at is None
        ]
        print(
            f"Queue (started {datetime.now()}) | Completed : {len(completed)} | Pending : {len(pending)}"
        )
        for item in tqdm(self.items):
            obj = self.downloads_objects[item]
            if obj.completed_at is not None:
                continue
            # was started but not completed, file is not complete, so we start again
            if obj.started_at is not None and obj.destination is not None:
                if file_exists(obj.destination, False):
                    os.remove(obj.destination)
            self._download_item(item)