class ExtendedMLflowExperiment:
    """MLflow client wrapper exposing TensorBoard-style media logging.

    An instance is bound to a single MLflow run. ``add_image``, ``add_images``,
    ``add_video`` and ``add_figure`` mimic the TensorBoard experiment methods of
    the same names: the media is written to a temporary file, uploaded as an
    artifact of the run under a directory named after ``tag``, and the local
    file is deleted. Any other attribute is looked up on the wrapped MLflow
    client (see ``__getattr__``), so the object can stand in for the client.
    """

    def __init__(self, mlflow_client: "MlflowClient", run_id: str):
        """
        Args:
            mlflow_client: Client used for uploads and for attribute fallback.
            run_id: ID of the MLflow run that artifacts are logged to.
        """
        self._mlflow_client = mlflow_client
        self._run_id = run_id
        # Scratch space for media files; each file is removed right after upload.
        self._tempdir = TemporaryDirectory()

    def _get_tmp_prefix_for_step(self, step: int) -> str:
        """Return a scratch path (without extension) for ``step``.

        The step is zero-padded to 7 digits, e.g. ``<tmpdir>/0000042``.
        """
        return os.path.join(self._tempdir.name, f"{step:07d}")

    def _log_and_remove(self, filename: str, artifact_path: str) -> None:
        """Upload ``filename`` to the run under ``artifact_path``, then delete it.

        The local file is removed even if the upload raises, so failed uploads
        do not leave files behind in the scratch directory.
        """
        try:
            self._mlflow_client.log_artifact(self._run_id, filename, artifact_path)
        finally:
            os.remove(filename)

    def add_video(self, vid_tensor, fps: int, tag: str, global_step: int):
        """Encode ``vid_tensor`` as a video file and log it as a run artifact.

        Args:
            vid_tensor: Video tensor in whatever layout ``prepare_video_tensor``
                accepts (TODO confirm; presumably TensorBoard's layout).
            fps: Frame rate passed to ``write_video_tensor``.
            tag: Artifact directory. TensorBoard tags are typically split on
                ``"/"``, which maps them onto nested artifact folders.
            global_step: Step used as the zero-padded base name of the file.
        """
        filename = write_video_tensor(
            self._get_tmp_prefix_for_step(global_step),
            prepare_video_tensor(vid_tensor),
            fps,
        )
        self._log_and_remove(filename, tag)

    def add_image(
        self, img_tensor: "torch.Tensor", dataformats: str, tag: str, global_step: int
    ):
        """Write ``img_tensor`` to an image file and log it as a run artifact.

        Args:
            img_tensor: Image tensor laid out as described by ``dataformats``.
            dataformats: Axis layout string such as ``"CHW"``; an additional
                ``"N"`` axis (e.g. ``"NCHW"``) denotes a batch of images.
            tag: Artifact directory. TensorBoard tags are typically split on
                ``"/"``, which maps them onto nested artifact folders.
            global_step: Step used as the zero-padded base name of the file.
        """
        filename = write_image_tensor(
            self._get_tmp_prefix_for_step(global_step),
            prepare_image_tensor(img_tensor, dataformats=dataformats),
        )
        self._log_and_remove(filename, tag)

    def add_images(self, img_tensor, dataformats: str, tag: str, global_step: int):
        """Log a batch of images.

        This is ``add_image`` with an additional ``N`` axis in ``dataformats``;
        the batch handling happens inside ``prepare_image_tensor``.
        """
        self.add_image(img_tensor, dataformats, tag, global_step)

    def add_figure(self, figure, close: bool, tag: str, global_step: int):
        """Render a figure (or a list of figures) and log it as an image artifact.

        Args:
            figure: A single figure, or a list of figures logged as one batch.
            close: Forwarded to ``figure_to_image``; presumably closes the
                figure(s) after rendering (TODO confirm).
            tag: Artifact directory, as in ``add_image``.
            global_step: Step used as the zero-padded base name of the file.
        """
        # A list is rendered as a batch with a leading N axis; a single figure as CHW.
        dataformats = "NCHW" if isinstance(figure, list) else "CHW"
        self.add_image(
            figure_to_image(figure, close),
            dataformats=dataformats,
            tag=tag,
            global_step=global_step,
        )

    def __getattr__(self, name):
        """Fall back to the wrapped MLflow client for attributes not found here.

        This makes the experiment object still behave like the regular MLflow
        client. Delegating is less explicit than writing a forwarding method for
        each client method, but it saves a lot of handcrafted code and lets
        PyTorch Lightning's MLflow logger use this object in place of the client.

        Raises:
            AttributeError: If ``name`` is a dunder, if no client has been set on
                this instance, or if the client has no such attribute.
        """
        missing = f"{type(self).__name__!r} object has no attribute {name!r}"
        # copy/pickle probe protocol hooks (__getstate__, __setstate__,
        # __deepcopy__, ...) on the wrapper. Forwarding them would make those
        # protocols operate on the client rather than on this object.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(missing)
        # Read the client from the instance dict, not via ``self._mlflow_client``.
        # On an instance built without __init__ (copy/unpickle call __new__
        # first), that lookup would re-enter __getattr__ and recurse until a
        # RecursionError instead of raising AttributeError.
        try:
            client = self.__dict__["_mlflow_client"]
        except KeyError:
            raise AttributeError(missing) from None
        return getattr(client, name)