Skip to content
Snippets Groups Projects
Commit c939f2ba authored by Valentin Simon Lüdke's avatar Valentin Simon Lüdke
Browse files

Add upscaling offrame of flow direction enable use of highresolution dem

parent d32b5550
No related branches found
No related tags found
3 merge requests!16Generate classical mhm setups from classical setup,!15Draft: Cut classical mhm setups,!8Draft: Resolve "Pre-Proc: incorporate tools for creating global setups"
Pipeline #257941 failed with stages
in 31 seconds
......@@ -89,11 +89,18 @@ def add_args(parser):
"""coordinates in the form of 'lon_min,lon_max,lat_min,lat_max,resolution_l0'"""
),
)
required_args.add_argument(
"--l1_resolution",
required=False,
default=None,
help=("""Resolution of the mHM target grid in degrees. If given the grid will be upscaled to this resolution."""),
)
parser.add_argument(
"--mask_file",
default=None,
help=("Path where to save the mask file"),
)
parser.add_argument("--frame", default=0, type=int, help=("Creates a frame of nonflow cells around the domain to enable non global domains in ulysses mrm which connects the eastern and western boundaries."))
parser.add_argument("--log_level", default="INFO", type=str, help=("Logging level"))
......@@ -133,4 +140,6 @@ def run(args):
gauge_coords=gauge_coords,
coordinate_slices=coordinate_slices,
mask_file=args.mask_file,
target_resolution=args.l1_resolution,
frame=args.frame
)
......@@ -100,6 +100,7 @@ class Catchment:
transform=None,
out_var_name=None,
do_shift=False,
target_resolution=None,
**kwargs,
):
self.flwdir = None
......@@ -107,9 +108,12 @@ class Catchment:
self.upgrid = None
self.uparea_grid = None
self.grdare = None
self.input_ds = None
self.elevtn = None
self._fdir = None
self.ftype = ftype
self.catchment_mask = None
self.target_resolution=target_resolution
self.out_var_name = (
out_var_name if out_var_name is not None else f"{var_name}.nc"
)
......@@ -119,19 +123,22 @@ class Catchment:
self.ds = ds
self.transform = transform
data = self._modify_data(self.ds[var_name])
self.input_ds = self._modify_data(self.ds[var_name])
if self.do_shift and self.is_data_global:
transform = list(self.transform)
transform[2] = 0
self.transform = tuple(transform)
if var == "fdir":
self.add_fdir(data=data.data, ftype=ftype, **kwargs)
self.add_fdir(**kwargs)
elif var == "dem":
self.add_dem(data=data, **kwargs)
self.add_dem(**kwargs)
else:
raise NotImplementedError
if self.target_resolution is not None:
self.upscale(var)
@property
def is_data_global(self):
......@@ -154,12 +161,12 @@ class Catchment:
return np.roll(data, int(len(self.ds.lon) / 2), axis=1)
return data
def add_dem(self, data, **kwargs):
def add_dem(self, **kwargs):
"""
Inits the FlwdirRaster class from dem.
"""
# perform checks
self.elevtn = data.data
self.elevtn = self.input_ds.data
if self._fdir is None:
# Create a flow direction object
logger.info("add_dem: kwargs: ", kwargs)
......@@ -171,17 +178,18 @@ class Catchment:
)
self.get_fdir()
def add_fdir(self, data, ftype, **kwargs):
def add_fdir(self, **kwargs):
"""
Inits the FlwdirRaster class from fdir.
"""
data = self.input_ds.data
# perform check
if self._fdir is None:
mask = np.isnan(data)
if mask.any():
data[mask] = FDIR_FILLVALUE[ftype]
data[mask] = FDIR_FILLVALUE[self.ftype]
data = data.astype(np.uint8)
self._fdir = pyflwdir.from_array(data=data, ftype=ftype, **kwargs)
self._fdir = pyflwdir.from_array(data=data, ftype=self.ftype, **kwargs)
self.get_fdir()
def delineate_basin(self, gauge_coords, stream_order=4):
......@@ -196,22 +204,51 @@ class Catchment:
self.catchment_mask = self.basin > 0
if np.all(~self.catchment_mask):
if stream_order>1:
logger.warning(f'Reducing stream order to {stream_order - 1}')
return self.delineate_basin((gauge_coords[0], gauge_coords[1]), stream_order=stream_order-1)
logger.error("No catchment found for the given coordinates")
if not np.any(np.isnan(self.basin)):
self.basin[np.where(~self.catchment_mask)] = self.VARIABLES["basin"]["_FillValue"]
def upscale(self, var):
"""Upscale flow direction to taget_resolution if that is int multipe of data resolution."""
input_lon = self.input_ds['lon'].data
input_res = round(abs(input_lon[1]-input_lon[0]),6)
if int(self.target_resolution / input_res + 0.5) - (self.target_resolution / input_res) < 1e6:
factor = int(self.target_resolution / input_res + 0.5)
else:
not_int_multiple_msg = f"Upscaling only works if L1 resolution is integer muplipe of L0 resolution but L1 = {self.taget_resolution / input_res:.4f} * L0"
raise ValueError(not_int_multiple_msg)
self._fdir, index = self._fdir.upscale(factor, method='ihu', uparea=None)
self.get_fdir()
if var == 'dem':
lat_size, lon_size = self.input_ds.shape
# Ensure the dimensions are evenly divisible by the factor
if lat_size % factor != 0 or lon_size % factor != 0:
raise ValueError("Data dimensions must be divisible by the upscaling factor of {factor}.")
# Reshape and aggregate data
reshaped = self.input_ds.values.reshape(
lat_size // factor, factor,
lon_size // factor, factor
)
aggregated = reshaped.mean(axis=(1, 3)) # Conservative mean over each block
# Create new DataArray
self.elevtn = aggregated
def get_basins(self):
"""
Performs the calculation of the catchment ids
"""
self.basin = self._fdir.basins()
def get_fdir(self, ftype=None):
def get_fdir(self):
"""
Performs the calculation of the flow direction
"""
self.flwdir = self._fdir.to_array(ftype=ftype or OUTPUT_FTYPE)
self.flwdir = self._fdir.to_array(ftype=self.ftype or OUTPUT_FTYPE)
def get_upstream_area(self):
"""
......@@ -230,6 +267,20 @@ class Catchment:
data[~self._fdir.mask.reshape(data.shape)] = 0
self.uparea_grid = self._fdir.accuflux(data, nodata=0)
@staticmethod
def create_frame(ds, frame=0):
"""If a frame is used this frame is set to no data values as a frame"""
for var in ds.data_vars:
data = ds.variables[var].data[:]
# set bounds to -9999.
data[:frame, :] = 0.
data[-frame:, :] = 0
data[:, :frame] = 0
data[:, -frame:] = 0
ds.variables[var].data[:] = data
return ds
def write(
self,
out_path,
......@@ -238,14 +289,15 @@ class Catchment:
cellsize=None,
cut_by_basin=False,
mask_file=None,
frame=1,
buffer=0
):
data_vars = {}
out_path = pl.Path(out_path)
data = self.basin
if not out_path.is_dir():
out_path.mkdir(parents=True, exist_ok=True)
if cut_by_basin:
lat_slice, lon_slice = self.cut_to_filled_area()
lat_slice, lon_slice = self.cut_to_filled_area(buffer)
else:
lat_slice, lon_slice = slice(84, -56), slice(None)
......@@ -255,13 +307,19 @@ class Catchment:
data[~self.catchment_mask] = self.VARIABLES[var_name]["_FillValue"]
if data is None:
continue
res = self.res
lon = self.ds.lon
lat = self.ds.lat
lon = np.arange(lon.min() + res/2, lon.max()+res/2, res)
lat = np.arange(lat.max()+res/2, lat.min()+res/2, -res)
data_var = xr.Dataset(
{var_name: (["lat", "lon"], self._revert_data(data))},
coords={
"lon": self.ds.lon, # [slice(3555, 3565)],
"lat": self.ds.lat, # [slice(860, 870)],
"lon": lon, # [slice(3555, 3565)],
"lat": lat, # [slice(860, 870)],
},
)
if single_file:
data_vars[var_name] = data_var
else:
......@@ -336,7 +394,14 @@ class Catchment:
logger.debug(f"lat_slice: {lat_slice}, lon_slice: {lon_slice}")
logger.debug(f"ds: {ds}")
mask = ds.basin > 0
if self.ftype == 'ldd':
sink_value = 5
elif self.ftype == 'd8':
sink_value = 0
ds['flwdir'].data[:] = ds.flwdir.where(~((mask) & ((ds.flwdir == np.nan) | (ds.flwdir < 0))), sink_value).data[:]
ds = ds.sel(lat=lat_slice, lon=lon_slice)
ds = self.create_frame(ds, frame)
ds.to_netcdf(
out_path / self.out_var_name,
encoding={
......@@ -350,14 +415,13 @@ class Catchment:
logger.info(f"Basin Id has been written to {out_path / self.out_var_name}")
# use basin_id to create a mask file
if mask_file is not None:
mask = ds.basin > 0
# name the variable mask
mask_file = pl.Path(mask_file)
mask = xr.Dataset({"mask": mask}, coords={"lon": ds.lon, "lat": ds.lat})
mask.to_netcdf(mask_file)
logger.info(f"Mask file has been written to {mask_file}")
def cut_to_filled_area(self):
def cut_to_filled_area(self, buffer=0):
"""Create lat and lon slices to cut the data to the filled area."""
# Find the non-zero elements
cols = np.any(
......@@ -371,10 +435,10 @@ class Catchment:
min_row, max_row = np.where(rows)[0][[0, -1]]
min_col, max_col = np.where(cols)[0][[0, -1]]
# Add a buffer of one cell
min_row = min_row - 1 if min_row > 0 else min_row
min_col = min_col - 1 if min_col > 0 else min_col
max_row = max_row + 1 if max_row < self.catchment_mask.shape[0] else max_row
max_col = max_col + 1 if max_col < self.catchment_mask.shape[1] else max_col
min_row = min_row - buffer if min_row > 0 else min_row
min_col = min_col - buffer if min_col > 0 else min_col
max_row = max_row + buffer if max_row < self.catchment_mask.shape[0] else max_row
max_col = max_col + buffer if max_col < self.catchment_mask.shape[1] else max_col
# Slice the array to extract the filled part
lon_min, lon_max = np.round(self.ds.lon.values[min_col], 3), np.round(
......@@ -441,6 +505,8 @@ def create_catchment(
gauge_coords=None,
coordinate_slices=None,
mask_file=None,
target_resolution=None,
frame = 0
):
logger.info(
......@@ -450,7 +516,7 @@ def create_catchment(
if var not in {"fdir", "dem"}:
raise ValueError(f"Unexpected value for var={var}, must be 'fdir' or 'dem'")
ds = xr.open_dataset(pl.Path(input_file))
# transform
transform = get_transformation_matrix_nc(ds, var_name)
......@@ -480,6 +546,7 @@ def create_catchment(
latlon=latlon,
out_var_name=temp_file2,
do_shift=True,
target_resolution=target_resolution
)
catchments = [global_catchments, global_catchments_shifted]
......@@ -488,7 +555,7 @@ def create_catchment(
c.get_facc()
c.get_grid_area()
# c.get_upstream_area()
c.write(output_path, single_file=True)
c.write(output_path, single_file=True, frame=frame)
# add paths to the temp files
temp_file1 = pl.Path(output_path, "hydro1.nc")
temp_file2 = pl.Path(output_path, "hydro2.nc")
......@@ -513,11 +580,12 @@ def create_catchment(
latlon=latlon,
out_var_name="basin_ids.nc",
do_shift=False,
target_resolution=target_resolution
)
c.get_basins()
c.get_facc()
c.get_grid_area()
c.write(output_path, single_file=True)
c.write(output_path, single_file=True, mask_file=mask_file, frame=frame)
else:
logger.info(f"Creating catchment for gauge coordinates {gauge_coords}")
c = Catchment(
......@@ -529,8 +597,9 @@ def create_catchment(
latlon=latlon,
out_var_name="basin_ids.nc",
do_shift=False,
target_resolution=target_resolution
)
c.delineate_basin(gauge_coords)
c.get_facc()
c.get_grid_area()
c.write(output_path, single_file=True, cut_by_basin=True, mask_file=mask_file)
c.write(output_path, single_file=True, cut_by_basin=True, mask_file=mask_file, frame=frame, buffer=frame+1)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment