Source code for xarray_video.backend

# Video backend for xarray, based on the xarray rasterio backend.


import os
import warnings
import tempfile

import numpy as np
import numcodecs
import av

from xarray import DataArray, Dataset
from xarray.core import indexing
from xarray.core.utils import is_scalar
from xarray.backends.common import BackendArray
from xarray.backends.file_manager import CachingFileManager
from xarray.backends.locks import SerializableLock

from .exceptions import VideoReadError

# Lock handed to the file manager and the array wrapper created by open_video.
VIDEO_LOCK = SerializableLock()

# Package-specific directory below the system temp directory.
# makedirs(exist_ok=True) avoids the check-then-create race of
# "if not exists: mkdir" when several processes import this module at once.
TEMPDIR = os.path.join(tempfile.gettempdir(), "xarray_video")
os.makedirs(TEMPDIR, exist_ok=True)

# h264 codecs looked up in the numcodecs registry: one with the codec's default
# settings and one with crf=0 (presumably lossless, as the name suggests).
compressor = numcodecs.registry.get_codec(dict(id="h264"))
lossless_compressor = numcodecs.registry.get_codec(dict(id="h264", crf=0))

def _key_length(key, length):
    """Return how many elements ``key`` selects along an axis of size ``length``.

    A slice is resolved against ``length``, a scalar key selects exactly one
    element, and any other kind of key is treated as spanning the whole axis.
    """
    if not isinstance(key, slice):
        return 1 if is_scalar(key) else length
    start, stop, step = key.indices(length)
    return len(range(start, stop, step))


class VideoArrayWrapper(BackendArray):
    """Lazily indexable, read-only view of the first video stream of a file.

    Frames are decoded on demand as ``uint8`` RGB (``rgb24``) arrays and are
    indexed as ``(frame, pixel_y, pixel_x, channel)``.

    Args:
        manager: File manager whose ``acquire()`` returns an opened PyAV
            container and whose ``close()`` releases it again.
        lock: Lock associated with the file. It is stored on the instance but
            not acquired by this class.
        shape: ``(frames, height, width, channels)`` of the video.
    """

    def __init__(self, manager, lock, shape):
        self.manager = manager
        self.lock = lock
        self._shape = shape
        self._dtype = np.dtype("uint8")
        # Fall back to sequential decoding when the probe below yields no
        # usable frame; previously the attribute was left undefined then.
        self._can_seek = False

        reader = manager.acquire()
        try:
            stream = reader.streams.video[0]
            start_time = stream.start_time or 0
            # Probe seeking: jump towards frame 100 and check whether the first
            # decoded frame really lies after the start of the stream.
            ts0 = int((100 * av.time_base) / stream.average_rate) + start_time
            reader.seek(ts0)
            for frame in reader.decode(stream):
                if frame.dts is not None:
                    self._can_seek = (frame.dts - start_time) > 0
                break
        finally:
            # Always release the container so the next acquire() starts from a
            # freshly opened file.
            manager.close()

    @property
    def dtype(self):
        """Data type of the decoded frames (always ``uint8``)."""
        return self._dtype

    @property
    def shape(self):
        """Shape of the video as ``(frames, height, width, channels)``."""
        return self._shape

    def _getitem(self, key):
        """Decode the frames selected by a basic (integer / slice) 4D key.

        Args:
            key: Tuple ``(frame_key, y_key, x_key, band_key)``.

        Returns:
            ``uint8`` array with one axis per non-scalar key; axes indexed by
            a scalar are dropped, as with NumPy basic indexing. Frames that
            could not be decoded are left as zeros.

        Raises:
            IndexError: If the frame slice has a negative step.
        """
        assert len(key) == 4, "video DataArrays should always be 4D"
        frame_key, y_key, x_key, band_key = key
        n_frames = self._shape[0]

        drop_frame_axis = False
        if isinstance(frame_key, slice):
            # slice.indices() resolves None, negative and out-of-range bounds.
            f0, f1, fstep = frame_key.indices(n_frames)
            if fstep < 0:
                # Frames are decoded strictly forwards.
                raise IndexError(
                    "negative frame steps are not supported by the video backend"
                )
        elif is_scalar(frame_key):
            f0 = frame_key + n_frames if frame_key < 0 else frame_key
            f1, fstep = f0 + 1, 1
            drop_frame_axis = True
        else:
            f0, f1, fstep = 0, n_frames, 1
        nf = len(range(f0, f1, fstep))

        # Only non-scalar keys keep their axis in the per-frame result. The
        # slice test must come first because is_scalar() is also true for
        # slice objects.
        kept_axes = [
            _key_length(k, n)
            for k, n in zip((y_key, x_key, band_key), self._shape[1:])
            if isinstance(k, slice) or not is_scalar(k)
        ]
        data = np.zeros((nf, *kept_axes), dtype=self._dtype)

        if nf > 0:
            reader = self.manager.acquire()
            try:
                stream = reader.streams.video[0]
                if self._can_seek:
                    ts0 = int((f0 * av.time_base) / stream.average_rate) + (
                        stream.start_time or 0
                    )
                    reader.seek(ts0)
                    # Negative means "not known yet": the index of the first
                    # decoded frame is derived from its timestamp below.
                    frame_start = -1
                else:
                    frame_start = 0

                filled = 0
                for i, frame in enumerate(reader.decode(video=0)):
                    if frame_start < 0:
                        dts = frame.dts
                        if dts is None:
                            # Some frames carry no dts; use the presentation
                            # timestamp instead, or assume the stream start.
                            dts = frame.pts if frame.pts is not None else 0
                        frame_start = int(dts * stream.time_base * stream.rate)
                    ind = frame_start + i
                    if ind < f0:
                        continue
                    if ind >= f1:
                        break
                    # Steps are counted from the first requested frame, not
                    # from frame 0 of the video.
                    if (ind - f0) % fstep == 0:
                        image = frame.to_ndarray(format="rgb24")
                        data[filled] = image[y_key, x_key, band_key]
                        filled += 1
                        if filled == nf:
                            break
            finally:
                self.manager.close()

        # Drop only the axis indexed by a scalar; length-1 slices keep theirs.
        return data[0] if drop_frame_axis else data

    def __getitem__(self, key):
        return indexing.explicit_indexing_adapter(
            key, self.shape, indexing.IndexingSupport.BASIC, self._getitem
        )
def _open_video(filename, mode):
    """Open ``filename`` with PyAV in the given ``mode`` and return the container."""
    return av.open(filename, mode=mode)


def _write_video(filename, array, fps=25, metadata=None):
    """Encode a stack of RGB frames as an H.264 stream in an MP4 container.

    Args:
        filename: Path of the video file to create.
        array: 4D array whose shape unpacks as ``(frames, height, width,
            channels)``; each frame is handed to PyAV as ``rgb24`` data
            (presumably ``uint8`` with three channels -- confirm with callers).
        fps: Frame rate of the encoded stream.
        metadata: Accepted for interface compatibility; it is currently not
            written to the file.
    """
    # Validate the shape before the output file is created.
    _, height, width, _ = array.shape

    # The context manager closes the container even when encoding fails.
    with av.open(filename, mode="w", format="mp4") as writer:
        stream = writer.add_stream("h264", rate=fps)
        stream.thread_type = "AUTO"
        stream.width = width
        stream.height = height
        stream.pix_fmt = "yuvj420p"

        for image in array:
            frame = av.VideoFrame.from_ndarray(image, format="rgb24")
            for packet in stream.encode(frame):
                writer.mux(packet)

        # Flush the frames still buffered inside the encoder.
        for packet in stream.encode():
            writer.mux(packet)
def open_video(filename, start_time=None, **kwargs):
    """Open a video file as a lazily loaded xarray dataset.

    The video is exposed as the ``video`` DataArray with dimensions
    ``(frame, pixel_y, pixel_x, channel)``. If a start time is provided, a
    ``time`` coordinate is created along the ``frame`` dimension and used as
    its index.

    Args:
        filename (string): filename of the video to open
        start_time (:class:`numpy.datetime64`): Start time of video
        **kwargs: Extra keyword arguments forwarded to ``av.open``.

    Returns:
        dataset (:class:`xarray.Dataset`): Dataset with video as a DataArray

    Raises:
        VideoReadError: The file does not contain a video stream. Errors
            raised while opening the file itself are propagated unchanged.
    """
    manager = CachingFileManager(
        _open_video,
        filename,
        lock=VIDEO_LOCK,
        mode="r",
        kwargs=kwargs,
    )

    try:
        reader = manager.acquire()
        try:
            stream = reader.streams.video[0]
        except IndexError as err:
            raise VideoReadError(
                f"{filename!r} does not contain a video stream"
            ) from err
        codec = stream.codec_context

        frames = stream.frames
        # If the frame count is not in the metadata, this is likely a matroska
        # file; seeking will then likely not work either. Scan the file by
        # demuxing it and count the packets that carry data.
        if frames == 0:
            frames = sum(
                1 for packet in reader.demux(stream) if packet.buffer_size > 0
            )

        rate = stream.average_rate
        fps = int(rate)
        width = codec.width
        height = codec.height

        data = indexing.LazilyIndexedArray(
            VideoArrayWrapper(manager, VIDEO_LOCK, (frames, height, width, 3))
        )
    except Exception:
        # Do not leave the file open when setting up the dataset fails.
        manager.close()
        raise

    has_time_axis = start_time is not None

    coords = {
        "channel": ["R", "G", "B"],
        "pixel_x": np.arange(width),
        "pixel_y": np.arange(height),
    }
    if has_time_axis:
        # One offset per frame, computed from the frame index so the axis has
        # exactly ``frames`` entries, and from the exact stream rate so
        # non-integer rates such as 30000/1001 do not drift.
        offsets = (np.arange(frames) * 1000 / float(rate)).astype("<m8[ms]")
        coords["time"] = ("frame", np.datetime64(start_time) + offsets)
    else:
        coords["frame"] = np.arange(frames)

    # The "fps" attribute keeps its integer (truncated) value.
    attrs = {"fps": fps, "_video": codec.name}

    dataset = Dataset(
        data_vars={
            "video": DataArray(
                data=data,
                dims=("frame", "pixel_y", "pixel_x", "channel"),
                coords=coords,
                attrs=attrs,
            )
        },
    )
    if has_time_axis:
        dataset = dataset.set_xindex("time")

    # Assign the preferred chunk sizes: whole frames with all channels.
    dataset["video"].encoding = {
        "preferred_chunks": {"channel": 3, "pixel_y": height, "pixel_x": width},
    }

    # Make the file closeable
    dataset.set_close(manager.close)
    return dataset