update_tfs_in_fortran_source.py 33.9 KB
Newer Older
1
2
#!/usr/bin/env python
# encoding: utf-8
3

4
5
6
7
8
9
10
11
12
"""
File Name   : check_transfer_funcs.py
Project Name: MPR
Description : analyzes mpr.nml file for transfer functions (TFs), compares it with source code and adds TFs to code
Author      : Stephan Thober (stephan.thober@ufz.de) and Robert Schweppe (robert.schweppe@ufz.de)
Created     : 2019-09-05 11:46
"""

# IMPORTS
13
import f90nml
14
import pathlib
15
16
from copy import copy
import re
17
from collections import OrderedDict
18
import argparse
19
import string
20
from shutil import copyfile
21
from src_python.pre_proc.mpr_interface import OPTIONS
22

23
# GLOBAL VARIABLES
Robert Schweppe's avatar
Robert Schweppe committed
24
25
FORTRAN_TF_SOURCEFILE = pathlib.Path('mo_mpr_transfer_func.f90')
FORTRAN_DA_SOURCEFILE = pathlib.Path('mo_mpr_data_array.f90')

MODIFIED_SUFFIX = '.mod'
BACKUP_SUFFIX = '.bak'
# characters that make a match part of a longer word (default `word_chars` of get_index_in_string)
WORD_BOUNDARIES = string.ascii_letters + string.digits + '_'
# maximum line length for Fortran code
MAX_LINE_LENGTH = 100
# default value of the --test_mode command line flag
EXPERIMENTAL = False

DEFAULT_CONFIG_FILE = pathlib.Path('mpr.nml')
DEFAULT_SOURCE_FOLDER = pathlib.Path('src')
# one indentation level in the generated Fortran code
FORTRAN_INDENT = '  '
35

36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Maps every operator, keyword and intrinsic function that may appear in a raw transfer function string
# to a short code. TF.translate_func replaces the keys (in this order, via plain str.replace) by
# '_<code>_' to build the key name of a TF, and TF._set_math_string pads each key with blanks for the
# Fortran expression. The order is significant: a key must come before every other key it contains
# ('**' before '*', 'log10' before 'log', '<=' before '<', 'atan2' before 'atan' before 'tan',
# 'sinh' before 'sin', ...), otherwise the longer token would be split up.
# NOTE(review): '-' and 'min' share the code 'mi' (and '**' and '^' share 'po'), so distinct raw strings
# could in principle map to the same key name. These codes end up in the names table of generated
# Fortran sources, so changing them would presumably break the lookup of already generated TFs.
TRANSLATE_DICT = OrderedDict([
    ('+', 'pl'),
    ('-', 'mi'),
    ('**', 'po'),
    ('*', 'ti'),
    ('/', 'di'),
    ('(', 'bs'),
    (')', 'be'),
    ('exp', 'ex'),
    ('log10', 'l1'),
    ('log', 'l2'),
    ('else', 'el'),
    ('if', 'if'),
    ('then', 'th'),
    ('end', 'en'),
    ('where', 'wh'),
    ('<=', 'le'),
    ('<', 'lt'),
    ('>=', 'ge'),
    ('>', 'gt'),
    ('==', 'eq'),
    ('.and.', 'ad'),
    ('.or.', 'or'),
    ('.not.', 'no'),
    ('asin', 'as'),
    ('acos', 'ac'),
    ('atan2', 'au'),
    ('atan', 'at'),
    ('sinh', 'sh'),
    ('cosh', 'ch'),
    ('tanh', 'tx'),
    ('sin', 'si'),
    ('cos', 'co'),
    ('tan', 'ta'),
    ('abs', 'ab'),
    ('max', 'ma'),
    ('min', 'mi'),
    ('sqrt', 'sq'),
    ('^', 'po')
])

# this dict is important as it stores characters that need to be ignored when inserting whitespaces around
# operators as this would break the meaning e.g. do not do "1<=2" -> "1 < =2" but "1 <= 2" for operator "<"
# Values are passed as `word_chars` to replace_in_string: if one of these characters directly precedes or
# follows a match of the key, the match is treated as part of a longer token and skipped. Keys that are
# not listed here are matched without any boundary check (TF._set_math_string uses a default of '').
WORD_CHARS_DICT = {
    '<': '=',  # keep '<=' together
    '>': '=',  # keep '>=' together
    '*': '*',  # keep '**' together
    'log': '1',  # do not match 'log' inside 'log10'
    'sin': 'ah',  # do not match inside 'asin' / 'sinh'
    'cos': 'ah',  # do not match inside 'acos' / 'cosh'
    'tan': 'ah2',  # do not match inside 'atan' / 'atan2' / 'tanh'
    'atan': '2',  # do not match inside 'atan2'
}

90

91
# FUNCTIONS
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def insert_line_break_in_math_string(start_string, insert_string, end_string, *args, index_addon=0):
    """Glue ``start_string + insert_string + end_string (+ args)`` without blanks at the seams.

    Trailing blanks of ``start_string`` and leading blanks of ``end_string`` are dropped, any extra
    positional strings are appended at the end.

    Returns
    -------
    tuple of (str, int)
        The joined string and ``index_addon`` shifted by the net number of characters added at the
        seam (length of ``insert_string`` minus the blanks that were removed).
    """
    head = start_string.rstrip(' ')
    tail = end_string.lstrip(' ')
    removed_blanks = (len(start_string) - len(head)) + (len(end_string) - len(tail))
    joined = head + insert_string + tail + ''.join(args)
    return joined, index_addon + len(insert_string) - removed_blanks


def break_line(line, string_only=False):
    """Break a long string (e.g. a TF name or math string) into parts not longer than MAX_LINE_LENGTH.

    Every line of ``line`` is handled separately. Working backwards over the multiples of
    MAX_LINE_LENGTH, the line is broken with a Fortran continuation (``&``):

    * inside a double-quoted string (or always if ``string_only`` is set) by closing the string,
      concatenating with ``//`` and reopening it on the next line,
    * inside a single-quoted string the same way with single quotes,
    * otherwise after the last blank before the break position.

    Parameters
    ----------
    line : str
        The (possibly multi-line) text to break.
    string_only : bool
        Treat the text as the content of a double-quoted string.

    Returns
    -------
    str
        The text with continuation line breaks inserted.
    """
    continuation_indent = FORTRAN_INDENT * 3
    parts = line.splitlines(True)
    for i_part, part in enumerate(parts):
        for break_index in range((len(part) - 1) // MAX_LINE_LENGTH, 0, -1):
            # if at position MAX_LINE_LENGTH, there is a string, split the string
            break_pos = break_index * MAX_LINE_LENGTH
            head, tail = part[:break_pos], part[break_pos:]
            if head.count('\"') % 2 or string_only:
                part, _ = insert_line_break_in_math_string(
                    head, '"// &\n{}"'.format(continuation_indent), tail)
            elif head.count('\'') % 2:
                # start, insert and end are separate arguments (previously they were concatenated
                # into one argument, which raised a TypeError)
                part, _ = insert_line_break_in_math_string(
                    head, "'// &\n{}'".format(continuation_indent), tail)
            # else insert line break at last occurrence of whitespace
            else:
                blank_pos = head.rfind(' ')
                part = part[:blank_pos + 1] + '&\n{}'.format(continuation_indent) + part[blank_pos + 1:]
        parts[i_part] = part
    return ''.join(parts)


def does_contain_pattern(pattern, search_string):
    """Search ``search_string`` for the regex ``pattern`` delimited by word boundaries.

    Used to detect e.g. ``if``/``where``/``else`` clauses in a TF string.

    Returns
    -------
    re.Match or None
        The first match, or None if the pattern does not occur as a whole word.
    """
    whole_word_regex = r'\b' + pattern + r'\b'
    return re.search(whole_word_regex, search_string)


128
def get_index_in_string(long_string, part, word_chars=WORD_BOUNDARIES):
    """Return the index of the first whole-word occurrence of ``part`` in ``long_string``, or -1.

    Port of a Fortran helper. An occurrence only counts if it is not glued to surrounding text:
    neither the character directly before nor the one directly after it may be in ``word_chars``.
    In addition, a ``part`` starting with a digit is not accepted right after a '-' (so that a
    positive number is not found inside a negative one).

    Raises IndexError if ``part`` is empty.
    """
    part_len = len(part)
    total_len = len(long_string)
    first_char = part[0]
    digit_prefixes = tuple(string.digits)
    pos = 0
    while True:
        # jump to the next candidate position starting with the first character of part
        pos = long_string.find(first_char, pos)
        if pos == -1 or pos + part_len > total_len:
            # first character not found anymore, or part would run past the end
            return -1
        if long_string[pos:pos + part_len] != part:
            pos += 1
            continue
        if pos > 0:
            before = long_string[pos - 1]
            glued_before = before in word_chars or (before == '-' and part.startswith(digit_prefixes))
            if glued_before:
                # start boundary violated
                pos += 1
                continue
        if pos + part_len < total_len and long_string[pos + part_len] in word_chars:
            # end boundary violated
            pos += 1
            continue
        return pos


172
173
174
175
176
def replace_in_string(search_string, string_pattern, replacement, *args, **kwargs):
    """Replace all whole-word occurrences of ``string_pattern`` by ``replacement``.

    Word boundaries are determined by ``get_index_in_string`` (extra arguments, e.g. ``word_chars``,
    are passed through). Scanning resumes right after each inserted replacement, so the replacement
    itself is never searched again. Example::

        replace_in_string('s ss sss', 'ss', 's') -> 's s sss'
    """
    result = search_string
    scan_start = 0
    while True:
        hit = get_index_in_string(result[scan_start:], string_pattern, *args, **kwargs)
        if hit < 0:
            return result
        hit += scan_start
        result = result[:hit] + replacement + result[hit + len(string_pattern):]
        # continue the search from after the replacement
        scan_start = hit + len(replacement)
185
186
187
188
189
190


# CLASSES
class TF(object):
    """Class handling all string operations on a transfer function (TF).

    From the raw TF expression read from the namelist it derives

    * ``translated_name``: the key name identifying the TF in the Fortran lookup table,
    * ``math_string``: the Fortran statement(s) assigning ``func_result``,
    * ``index_name``: the name of the Fortran function ('transfer_function_<N>').
    """
    MATH_PREFIX_STRING = 'func_result(:) = '

    def __init__(self, raw_tf_string, predictors, is_test_mode=True, global_params=None, tfs=None):
        """Initialize the TF from the raw namelist information.

        Parameters
        ----------
        raw_tf_string : str
            The transfer function expression as given in the namelist.
        predictors : sequence of str
            Names of the input data arrays used in the expression; the i-th one becomes ``x(i)``.
        is_test_mode : bool
            Flag stored on the instance (not used by this class itself).
        global_params : mapping, optional
            Global parameters; only the keys (parameter names) are used here.
        tfs : dict, optional
            TFs already present in the Fortran source, mapping the quoted key name
            (e.g. '"x1_ti_p1"') to its function name ('transfer_function_<N>').
        """
        # set args
        self.tfs = tfs or {}
        self.global_params = global_params or {}
        self.raw_tf_string = raw_tf_string
        self.predictors = predictors
        self.is_test_mode = is_test_mode

        # set main properties
        self.index_name = ''
        self.processed_tf_string = copy(self.raw_tf_string)
        self.translated_name = ''
        self.math_string = ''

    @property
    def is_contained(self):
        """is the translated tf name already contained in source code"""
        return '"{}"'.format(self.translated_name) in self.tfs

    @property
    def n_params(self):
        """get the number of distinct parameters (p1, p2, ...) in the translated name"""
        return len(set(re.findall(r'p[0-9]+', self.translated_name)))

    @property
    def _max_index(self):
        """what is the next free index number for the transfer_function name"""
        # values of self.tfs look like 'transfer_function_<N>'
        return max([int(key.split('_')[2]) for key in self.tfs.values()] + [0]) + 1

    def translate_func(self):
        """Translate the raw TF string into Fortran code, a unique key name and the function name.

        Sets ``translated_name``, ``math_string`` and ``index_name``. If the key name is already
        contained in ``tfs``, the existing function name is reused and ``math_string`` stays empty.

        Raises
        ------
        Exception
            If the expression contains tokens that cannot be parsed (see ``_set_math_string``).
        """
        # initialize
        tf_string = copy(self.raw_tf_string)

        for i_pred, predictor in enumerate(self.predictors, 1):
            # replace all occurrences of predictor as a whole word
            tf_string = replace_in_string(tf_string, predictor, 'x{}'.format(i_pred))

        # check for existence of parameters as word, get their index of first occurrence
        first_occurrences = []
        for param in self.global_params.keys():
            index = get_index_in_string(tf_string, param)
            if index >= 0:
                first_occurrences.append((index, param))
        # number the parameters by order of first occurrence (stable sort: ties keep dict order)
        params = [param for _, param in sorted(first_occurrences, key=lambda item: item[0])]

        # replace parameters in that order by p1, p2, ... (in the working copy only)
        for i_param, param in enumerate(params, 1):
            tf_string = replace_in_string(tf_string, param, 'p{}'.format(i_param))

        # n_params (needed while creating the math string) reads translated_name,
        # so temporarily expose the intermediate form
        self.translated_name = tf_string
        # the containment check must use the final key name, as stored in the Fortran lookup table
        key_name = self._translate_to_key(tf_string)
        quoted_key_name = '"{}"'.format(key_name)

        if quoted_key_name in self.tfs:
            # the TF already exists in the Fortran source: reuse its function
            self.translated_name = key_name
            self.math_string = ''
            self.index_name = self.tfs[quoted_key_name]
        else:
            self.math_string = self._set_math_string(tf_string)
            self.translated_name = key_name
            # set the index name
            self.index_name = 'transfer_function_{}'.format(self._max_index)

    @staticmethod
    def _translate_to_key(tf_string):
        """Build the key name of a prepared TF string (predictors/parameters already replaced).

        Every key of TRANSLATE_DICT is replaced (in dict order) by '_<code>_', blanks are removed,
        runs of underscores are collapsed and underscores at both ends are stripped.
        """
        # ... replace operators
        for key, val in TRANSLATE_DICT.items():
            tf_string = tf_string.replace(key, '_{}_'.format(val))
        # ... eliminate blanks
        tf_string = tf_string.replace(' ', '')
        # ... remove underscores at beginning and end and multiple underscores
        return re.sub(r'[_]{2,}', r'_', tf_string).strip('_')

    def _set_math_string(self, prepared_string):
        """
        create mathematical expression for that string, e.g.:

        "func_result(:) = param(1) + param(2) * x(1)%data_p(:) + param(3) * x(2)%data_p(:)"

        Parameters
        ----------
        prepared_string : str
            TF string with predictors replaced by x<i> and parameters by p<i>.

        Returns
        -------
        str
            The Fortran statement(s), with if/where clauses expanded and long lines broken.

        Raises
        ------
        Exception
            If characters remain that are neither operators, parameters nor predictors.
        """
        math_string = copy(prepared_string)
        # everything known gets removed from remainder, whatever is left could not be parsed
        remainder = copy(prepared_string)
        for key in TRANSLATE_DICT.keys():
            word_chars = WORD_CHARS_DICT.get(key, '')
            # '^' is accepted as an alias of the power operator, but Fortran only knows '**'
            fortran_token = '**' if key == '^' else key
            math_string = replace_in_string(math_string,
                                            key,
                                            ' {} '.format(fortran_token),
                                            word_chars=word_chars)
            remainder = replace_in_string(remainder, key, '', word_chars=word_chars)

        # replace the parameters and data arrays by the Fortran syntax
        # (descending order so that e.g. p12 is replaced before p1)
        for i_param in range(self.n_params, 0, -1):
            math_string = math_string.replace('p{}'.format(i_param), 'param({})'.format(i_param))
            remainder = remainder.replace('p{}'.format(i_param), '')
        for i_pred in range(len(self.predictors), 0, -1):
            math_string = math_string.replace('x{}'.format(i_pred), 'x({})%data_p(:)'.format(i_pred))
            remainder = remainder.replace('x{}'.format(i_pred), '')

        # polish some user-defined whitespace hoipolloi
        math_string = re.sub(r'[ ]{2,}', r' ', math_string).strip(' ')
        remainder = re.sub(r'[ ]{2,}', r' ', remainder).strip(' ')
        if remainder:
            raise Exception('Could not successfully parse the following characters : ' +
                            ', '.join(['"{}"'.format(_) for _ in remainder.split(' ')]) +
                            ' in transfer function "{}".'.format(self.raw_tf_string))
        # handle the much more complex where or if clauses
        math_string = self._handle_if_or_where_clause(math_string)

        # insert line breaks at appropriate places in the function if became very long
        if (len(math_string) - 1) > MAX_LINE_LENGTH:
            math_string = break_line(math_string)
        return math_string

    def _handle_if_or_where_clause(self, math_string):
        """create the math string for if and where clauses

        If the math string does not start with 'if' or 'where' (re.match anchors at the start),
        MATH_PREFIX_STRING is prepended; otherwise the clauses are formatted into Fortran blocks.
        """
        triggers = ['if', 'where']
        if re.match(r'|'.join([r'\b{}\b'.format(trigger) for trigger in triggers]), math_string) is None:
            # we prepend the func_result = string
            math_string = self.MATH_PREFIX_STRING + math_string
        else:
            # now work on special commands like if-clauses and where-clauses
            for trigger in triggers:
                math_string = self._format_trigger_in_math_string(math_string, trigger)

        return math_string

    def _format_trigger_in_math_string(self, math_string, trigger):
        """Format an ``if`` or ``where`` clause of the math string as a multi-line Fortran block.

        After the closing parenthesis of the condition following the first ``trigger``, a line break
        and the func_result assignment are inserted ('then' is kept for 'if'). An ``else <trigger>``
        is moved to its own line, a plain ``else`` gets its own line followed by the assignment and
        the closing ``end <trigger>``. Without any else, a where-block gets an ``else where``
        branch assigning ``nodata_dp``, and ``end <trigger>`` is appended.

        Raises
        ------
        Exception
            If 'then' is missing after an if-condition or no '(' follows the trigger.
        """
        char1 = '('
        char2 = ')'
        # if (...) insert_strig_dict1 ... else ...
        # where (...) insert_strig_dict1 ... else where ...
        insert_strig_dict1 = {'if': 'then', 'where': ''}
        # if (...) then ... else insert_strig_dict2 ...
        # where (...) ... else insert_strig_dict2 ...
        insert_strig_dict2 = {'if': '', 'where': ' where'}
        # special case for where statements without else - set it to no_data in all cases
        insert_strig_dict3 = {'if': '', 'where': '{}{}{}{}{}{}'.format(
            '\n',
            FORTRAN_INDENT * 2,
            'else where',
            '\n',
            FORTRAN_INDENT * 3,
            self.MATH_PREFIX_STRING + 'nodata_dp')}
        end_string = '\n{}end {}'.format(FORTRAN_INDENT * 2, trigger)

        index = 0
        index_addon = 0
        contains_pattern = does_contain_pattern(trigger, math_string)
        if contains_pattern is not None:
            index = contains_pattern.end()
        while contains_pattern is not None:
            # open_parenthesis_counter of levels of nestedness ( char1 increases it, char2 decreases it)
            open_parenthesis_counter = 0
            # whether a pair of parenthesis is contained
            contained_parenthesis = False
            # loop over each character from trigger to end of math_string
            for index in range(index, len(math_string) + 1):
                # continue statements are omitted in each if clause, only execute something in if-clauses
                if open_parenthesis_counter == 0 and contained_parenthesis:
                    if trigger == 'if' and not math_string[index + index_addon:].startswith('then'):
                        raise Exception('The transfer function is not valid, a "then" needs to follow the '
                                        'logical expression of an if-statement for transfer function: ' +
                                        self.raw_tf_string)
                    # this code gets inserted, nice linebreak and indentation
                    insert_string = '{}{}{}'.format(
                        '{}\n'.format(insert_strig_dict1[trigger]),
                        FORTRAN_INDENT * 3,
                        self.MATH_PREFIX_STRING)
                    math_string, index_addon = insert_line_break_in_math_string(
                        math_string[:index + index_addon],
                        insert_string,
                        math_string[index + index_addon + len(insert_strig_dict1[trigger]):],
                        index_addon=index_addon,
                    )

                    # continue looking for else patterns
                    contains_pattern = does_contain_pattern(r'else[ ]+' + trigger, math_string[index + index_addon:])
                    contains_else = does_contain_pattern(r'else', math_string[index + index_addon:])

                    if contains_pattern:
                        insert_string = '{}{}{}'.format(
                            '\n',
                            FORTRAN_INDENT * 2,
                            'else {} '.format(trigger),
                        )
                        math_string, index_addon = insert_line_break_in_math_string(
                            math_string[:index + index_addon + contains_pattern.start()],
                            insert_string,
                            math_string[index + index_addon + contains_pattern.end():],
                            index_addon=index_addon + contains_pattern.start(),
                        )
                        contained_parenthesis = False
                    elif contains_else:
                        insert_string = '{}{}{}{}{}{}'.format(
                            '\n',
                            FORTRAN_INDENT * 2,
                            'else{}'.format(insert_strig_dict2[trigger]),
                            '\n',
                            FORTRAN_INDENT * 3,
                            self.MATH_PREFIX_STRING
                        )
                        math_string, index_addon = insert_line_break_in_math_string(
                            math_string[:index + index_addon + contains_else.start()],
                            insert_string,
                            math_string[index + index_addon + contains_else.end():],
                            end_string,
                            index_addon=index_addon + contains_else.start(),
                        )
                        contains_pattern = None
                        break
                    else:
                        # currently we only support one if or where clause per TF, so if no else is found:
                        math_string = '{}{}{}'.format(math_string, insert_strig_dict3[trigger], end_string)
                        contains_pattern = None
                        break
                elif math_string[index + index_addon] == ' ':
                    continue
                elif math_string[index + index_addon] == char2:
                    open_parenthesis_counter -= 1
                elif math_string[index + index_addon] == char1:
                    open_parenthesis_counter += 1
                    contained_parenthesis = True
                elif not contained_parenthesis:
                    raise Exception('The transfer function is not valid, a "' + char1 +
                                    '" needs to follow a "' + trigger +
                                    '"-statement for transfer function: ' + self.raw_tf_string +
                                    '\nException occurred at character ' + str(index + index_addon + 1))

        return math_string

    def insert_index(self, *args):
        """
        helper function to format the values set for indices property
        in TransferFunctionTable in mp_mpr_transfer_func.f90
        """
        return '{}_i4'.format(self.index_name.split('_')[-1], *args)

    def insert_name(self, *args):
        """
        helper function to format the values set for names property
        in TransferFunctionTable in mp_mpr_transfer_func.f90
        """
        return '"{}"'.format(break_line(self.translated_name, string_only=True), *args)
445
446
447


class SourceCode(object):
    """Parent class for the Fortran source files modified by this script.

    Holds the file content as one string (``source``) and the list of TF objects (``tfs``)
    that subclasses insert when building the updated source.
    """

    def __init__(self, filepath):
        self.source = self.read_fortran_tf_source(filepath)
        self.tfs = []

    def _retrieve_values(self, key, chars_to_delete=None):
        """Retrieve the values of a 1d array parameter property ``key = [...]``.

        Parameters
        ----------
        key : str
            Name of the array; the source is searched for ``'<key> = ['``.
        chars_to_delete : list of str, optional
            Substrings removed (in the given order) from the array content before splitting,
            e.g. Fortran continuation syntax.

        Returns
        -------
        list of str
            The comma-separated values, stripped of blanks. Commas inside quoted strings
            (e.g. TF names like "ma_bs_x1,x2_be") do not split a value.

        Raises
        ------
        ValueError
            If ``'<key> = ['`` or its closing ``']'`` cannot be found in the source.
        """
        chars_to_delete = chars_to_delete or []
        # set the pattern we look for in the string
        start_pattern = '{} = ['.format(key)
        start_index = self.source.find(start_pattern)
        if start_index < 0:
            raise ValueError('Could not find "{}" in the Fortran source.'.format(start_pattern))
        values_start = start_index + len(start_pattern)
        values_end = self.source.find(']', values_start)
        if values_end < 0:
            raise ValueError('Could not find the closing "]" of "{}" in the Fortran source.'.format(start_pattern))
        # select the part of the source we are actually interested in
        part = self.source[values_start:values_end]
        # remove the type specification, e.g. "character(256) ::"
        if '::' in part:
            part = part.split('::')[-1]
        # remove the unneeded Fortran syntax
        for chars in chars_to_delete:
            part = part.replace(chars, '')

        return [item.strip(' ') for item in self._split_outside_quotes(part)]

    @staticmethod
    def _split_outside_quotes(text, separator=','):
        """Split ``text`` at ``separator``, ignoring separators inside '...' or "..." strings."""
        items = []
        current = []
        quote = None
        for char in text:
            if quote is not None:
                # inside a string: only the matching quote ends it
                if char == quote:
                    quote = None
            elif char in '"\'':
                quote = char
            elif char == separator:
                items.append(''.join(current))
                current = []
                continue
            current.append(char)
        items.append(''.join(current))
        return items

    @staticmethod
    def read_fortran_tf_source(filepath):
        """read the Fortran source file and return its content as a string"""
        with open(filepath) as file:
            source = file.read()
        return source

    def add_tf(self, tf):
        """add a TF to the list of TFs to be inserted"""
        self.tfs.append(tf)

    @staticmethod
    def _paste_lines(t_lines, pos, p_lines, leading_blanks=''):
        """Insert ``p_lines`` after line ``pos`` of ``t_lines``, preceded by a marker comment line.

        Every line is inserted at ``pos + 1``, so ``p_lines`` end up in reversed order (callers pass
        them upside down) and the marker line ends up directly after ``pos``. ``t_lines`` is modified
        in place and also returned.
        """
        for paste in p_lines:
            t_lines.insert(pos + 1, paste)
        t_lines.insert(pos + 1, '{}! >>> inserted automatically by Python script'.format(leading_blanks))
        return t_lines

Robert Schweppe's avatar
Robert Schweppe committed
493

494
class TFSource(SourceCode):
    """Fortran source containing the transfer functions and their lookup table (names/indices arrays)."""

    def __init__(self, *args, **kwargs):
        super(TFSource, self).__init__(*args, **kwargs)
        # retrieve the list of tf names as dict
        self.tf_names = self.get_tf_names()

    def get_tf_names(self):
        """Retrieve names and indices of the transfer function lookup table as a dict.

        Returns
        -------
        dict
            Maps the quoted TF key name (as compared against ``'"{}"'.format(TF.translated_name)``)
            to its function name ``'transfer_function_<index>'``.
        """
        # Long names are written broken over several lines ('"abc"// &\n      "def"', see break_line
        # with string_only=True). Removing the continuation syntax leaves '"abc""def"'; removing '""'
        # afterwards restores '"abcdef"', otherwise such names would never match and be re-added.
        names = self._retrieve_values('names', ['&', '\n', ' ', '//', '""'])
        indices = self._retrieve_values('indices', ['&', '\n', ' ', '//', '_i4'])

        return dict(zip(names, ['transfer_function_{}'.format(item) for item in indices]))

    def get_source(self):
        """Update the source code and return it as a list of lines (strings).

        Extends the ``indices`` and ``names`` arrays (and their ``dimension(...)``) by the added TFs,
        inserts a ``public`` declaration two lines after the first line containing 'private' and
        inserts the function definitions before ``end module mo_mpr_transfer_func``.
        """
        tf_source = self.source.splitlines()
        # modify the lookup table
        for key in ['indices', 'names']:
            # start with the dimension length
            ii = [tf_source.index(ll) for ll in tf_source if '{} = ['.format(key) in ll][0]
            old_index = re.search(r'dimension\(([0-9]+)\)', tf_source[ii]).group(1)
            tf_source[ii] = tf_source[ii].replace('dimension({})'.format(old_index),
                                                  'dimension({})'.format(int(old_index) + len(self.tfs)))

            # now add the new values to the line
            while True:
                if ']' in tf_source[ii]:
                    # we are at last line of values, we need to insert the values here
                    for tf in self.tfs:
                        format_value = {'indices': tf.insert_index, 'names': tf.insert_name}[key]
                        # if new item would increase line length over max, insert line break
                        if len(tf_source[ii]) + len(format_value()) > MAX_LINE_LENGTH:
                            tf_source[ii] = tf_source[ii].replace(']', ', &')
                            tf_source.insert(ii + 1, ' ' * 6 + format_value() + ']')
                            ii += 1
                        # else simply append
                        else:
                            tf_source[ii] = tf_source[ii].replace(']', ', ' + format_value() + ']')
                    break
                else:
                    ii += 1

        # add function in mo_mpr_transfer_function
        # insert function declaration
        ii = [tf_source.index(ll) for ll in tf_source if 'private' in ll][0]
        for tf in self.tfs[::-1]:
            tf_source.insert(ii + 2, '  public :: {}'.format(tf.index_name))

        # insert function at the end
        for tf in self.tfs[::-1]:
            ii = [tf_source.index(ll) for ll in tf_source if 'end module mo_mpr_transfer_func' in ll][0] - 1
            # insert upside down
            func_string_list = [
                '  end function {}'.format(tf.index_name),
                '    {}'.format(tf.math_string),
                '    ',
                '    real(dp), dimension(size(x(1)%data_p, kind=ix)) :: func_result',
                '    real(dp), dimension(:), intent(in) :: param',
                '    type(InputFieldContainer), dimension(:), intent(in) :: x',
                '  pure function {}(x, param) result(func_result)'.format(tf.index_name),
                '  ! ----------------------------------------------------------------------------------------',
            ]
            tf_source = self._paste_lines(tf_source, ii, func_string_list, '  ')

        return tf_source

Robert Schweppe's avatar
Robert Schweppe committed
563

564
565
566
567
568
class DASource(SourceCode):
    """Fortran source containing the `call_transfer_func` subroutine that dispatches to the TFs."""

    def __init__(self, *args, **kwargs):
        super(DASource, self).__init__(*args, **kwargs)

    def get_source(self):
        """Update the source code and return it as a list of lines (strings).

        After the first line containing 'case' at or after the line containing
        'subroutine call_transfer_func' (blanks ignored, case-insensitive, last character dropped),
        a ``case`` branch calling each new TF is inserted. A ``use mo_mpr_transfer_func`` statement
        per TF is inserted two lines above the first line containing 'implicit none'.
        """
        da_lines = self.source.splitlines()

        in_dispatch_routine = False
        for line_no in range(len(da_lines)):
            compact = da_lines[line_no].replace(' ', '').lower()[:-1]
            if not in_dispatch_routine and 'subroutinecall_transfer_func' in compact:
                in_dispatch_routine = True
            if not (in_dispatch_routine and 'case' in compact):
                continue
            for tf in reversed(self.tfs):
                # listed upside down, _paste_lines reverses them
                case_lines = [
                    '      data = {}(inputFieldContainers, self%globalParameters)'.format(tf.index_name),
                    '      call self%check_transfer_func_args({}_i4, {}_i4, size(inputFieldContainers))'.format(
                        len(tf.predictors), tf.n_params),
                    "    case('{}')".format(tf.index_name),
                ]
                da_lines = self._paste_lines(da_lines, line_no, case_lines, '    ')
            break

        # insert use statement
        implicit_none_at = [da_lines.index(ll) for ll in da_lines if 'implicit none' in ll][0]
        for tf in reversed(self.tfs):
            da_lines.insert(implicit_none_at - 1, '  use mo_mpr_transfer_func, only: {}'.format(tf.index_name))

        return da_lines
593
594


595
class TFConverter(object):
    """wrapper class for handling the modification of Fortran source files to incorporate new transfer functions"""

    def __init__(self):
        # all the attributes
        # parsed command line arguments, set by parse_args()
        self.commandLineArgs = None
        # set by read_source_files(): True if at least one backup file had to be created in this run
        self.do_create_backup = True

        # the (parsed) Fortran source
        self.tf_source = None
        self.da_source = None

        # contents of the main and the (optional) global parameter namelist; empty dict if not read
        self.mpr_nml = {}
        self.mpr_global_parameter_nml = {}
        # final dict storing the Parameters
        self.global_params = {}

    @staticmethod
    def _option_key(index):
        """return the ``index``-th key of ``OPTIONS``

        NOTE(review): the keys appear to be (namelist group, variable name) pairs that are addressed by position,
        so this relies on the insertion order of ``OPTIONS`` — confirm against mpr_interface.
        """
        return list(OPTIONS.keys())[index]

    @staticmethod
    def _with_extra_suffix(filepath, suffix):
        """return ``filepath`` with ``suffix`` appended to its existing suffix (e.g. 'a.f90' -> 'a.f90.bak')"""
        return filepath.with_suffix(filepath.suffix + suffix)

    def parse_args(self):
        """parse command line arguments and store them in ``self.commandLineArgs``"""
        parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                         description='''Preprocessor script for MPR.

            author: Stephan Thober, Robert Schweppe
            created: Mar 2018''')
        parser.add_argument('-c', '--config_file', action='store', type=pathlib.Path,
                            default=DEFAULT_CONFIG_FILE, dest='config_file', metavar='config_file',
                            help="path to config file for MPR (Default: {})".format(DEFAULT_CONFIG_FILE))
        parser.add_argument('-p', '--parameter_file', action='store', type=pathlib.Path,
                            dest='param_file',
                            help="path to config file with extra parameters for MPR")
        parser.add_argument('-s', '--source_folder', action='store', type=pathlib.Path,
                            default=DEFAULT_SOURCE_FOLDER, dest='src_folder',
                            help="path to source code for MPR (Default: {})".format(DEFAULT_SOURCE_FOLDER))
        parser.add_argument('-t', '--test_mode', action='store_true',
                            default=EXPERIMENTAL, dest='is_test_mode',
                            help="whether to write to temporary files with '{}' suffix (Default: {})".format(
                                MODIFIED_SUFFIX, EXPERIMENTAL))
        parser.add_argument('--clean', action='store_true',
                            dest='from_bak',
                            help="base Fortran source on '{}' files".format(
                                BACKUP_SUFFIX))

        self.commandLineArgs = parser.parse_args()

    @staticmethod
    def check_for_backup(files_to_check_against=None, suffix=''):
        """check that none of the given files has a counterpart with ``suffix`` appended

        :param files_to_check_against: iterable of pathlib.Path objects, or None
        :param suffix: string appended to each file's own suffix (e.g. '.bak' -> 'file.f90.bak')
        :return: True if no such counterpart exists (or no files were given), False if at least one exists
        """
        if files_to_check_against is None:
            return True
        # all files do not exist -> True
        # at least one file exists -> False
        return not any(file.with_suffix(file.suffix + suffix).exists() for file in files_to_check_against)

    def read_source_files(self):
        """reads all Fortran source files to be modified and also all configuration files

        Missing backups ('<file>.bak') of the Fortran sources are created first. Each source is then read from
        '<file>.mod' (test mode, if that file exists), from a backup that existed before this run (``--clean``),
        or from the file itself. Afterwards the namelists are read and the global parameters are collected.
        """
        da_file = pathlib.Path(self.commandLineArgs.src_folder, FORTRAN_DA_SOURCEFILE)
        transfer_func_file = pathlib.Path(self.commandLineArgs.src_folder, FORTRAN_TF_SOURCEFILE)
        source_files = [da_file, transfer_func_file]

        # check per file whether its .bak file exists, if not create it; checking each file individually makes
        # sure a single deleted backup is recreated before its source gets overwritten, and that --clean never
        # tries to read a backup that does not exist
        missing_backups = [filepath for filepath in source_files
                           if self.check_for_backup([filepath], suffix=BACKUP_SUFFIX)]
        self.do_create_backup = bool(missing_backups)
        if missing_backups:
            print('creating backup files at {}'.format(pathlib.Path(da_file.parent, '*' + BACKUP_SUFFIX)))
            for filepath in missing_backups:
                copyfile(filepath, self._with_extra_suffix(filepath, BACKUP_SUFFIX))

        # read Fortran source files
        for filepath, target, target_type in zip(
                source_files,
                ['da_source', 'tf_source'],
                [DASource, TFSource],
        ):
            mod_path = self._with_extra_suffix(filepath, MODIFIED_SUFFIX)
            if self.commandLineArgs.is_test_mode and mod_path.exists():
                read_path = mod_path
            elif self.commandLineArgs.from_bak and filepath not in missing_backups:
                # the backup existed before this run, so it holds the source as it was when it was backed up
                read_path = self._with_extra_suffix(filepath, BACKUP_SUFFIX)
            else:
                # a backup created in this run is a plain copy of the file itself
                read_path = filepath
            print('reading Fortran source file from {}'.format(read_path))
            setattr(self, target, target_type(read_path))

        # read namelists
        targets = ['mpr_nml', 'mpr_global_parameter_nml']
        for filepath, target in zip([self.commandLineArgs.config_file, self.commandLineArgs.param_file], targets):
            if filepath is not None:
                setattr(self, target, self._read_namelist_source(filepath))

        self.global_params = self._get_parameters()

    @staticmethod
    def _read_namelist_source(filepath):
        """read content from *.nml files and return a dict-like Namelist instance (empty dict if file is missing)"""
        if not filepath.exists():
            return {}

        parser = f90nml.Parser()
        parser.global_start_index = 1
        return parser.read(filepath)

    def _get_parameters(self):
        """join the parameter names and values from the mpr config dicts to a global dict of parameters

        :return: dict {parameter_name (str): parameter_value}
        :raises ValueError: if a parameter name matches 'x<digits>' or 'p<digits>' at its start, as these names
            would clash with the generated placeholders
        """
        names_key = self._option_key(18)
        values_key = self._option_key(19)
        global_params = {}
        for _dict in (self.mpr_nml, self.mpr_global_parameter_nml):
            if _dict and names_key[0] in _dict:
                # create a dict with {parameter_names: parameter_values}
                global_params.update({str(k): v for k, v in zip(_dict[names_key], _dict[values_key])
                                      if k is not None})
        for key in global_params:
            if re.match(r'[xp]\d+', key):
                raise ValueError('Please do not use parameter names starting with "x" or "p" and followed by a '
                                 'number. You provided: ' + key)
        return global_params

    def parse_tfs(self):
        """parse the TFs from the Fortran source code and modify the source code according to configuration

        For each effective parameter with a transfer function, a TF object is built and, if it is not yet present
        in the Fortran source, added to both the transfer function and the data array source.
        """
        tfs = self.tf_source.tf_names
        # loop over effective params
        transfer_funcs = self.mpr_nml[self._option_key(23)]
        predictors_key = self._option_key(22)
        for ii, name in enumerate(self.mpr_nml[self._option_key(20)]):
            transfer_func = transfer_funcs[ii] if ii < len(transfer_funcs) else None

            if name is None or transfer_func is None:
                # no transfer function defined for this effective parameter
                continue

            if predictors_key[-1] in self.mpr_nml[predictors_key[0]]:
                all_predictors = self.mpr_nml[predictors_key]
                # the predictor list may be shorter than the list of names -> treat missing entries as empty
                row = all_predictors[ii] if ii < len(all_predictors) else None
                predictors = [_ for _ in (row or []) if _ is not None]
            else:
                predictors = []

            # handle the case, where there is a transfer function, but no from_data_arrays, then use the name ("self")
            if not predictors:
                predictors = [name]

            # initialize transfer function object
            tf = TF(
                transfer_func,
                predictors,
                self.commandLineArgs.is_test_mode,
                self.global_params,
                tfs
            )
            # generate the unique tf name and replace predictors and parameters in raw string
            tf.translate_func()

            # check if we already have the tf registered (in self.tfs)
            if not tf.is_contained:
                # add the TF to the source code
                self.tf_source.add_tf(tf)
                self.da_source.add_tf(tf)
                # register the new tf in the dict so future tfs know about the existing ones
                tfs['"{}"'.format(tf.translated_name)] = tf.index_name

        print('added {} new transfer functions to code:'.format(len(self.tf_source.tfs)))
        for tf in self.tf_source.tfs:
            print(f'-> "{tf.raw_tf_string}"')

    def write(self):
        """write the Fortran source files (to '<file>.mod' in test mode, otherwise in place)"""
        for filename, sourcecode in zip([FORTRAN_DA_SOURCEFILE, FORTRAN_TF_SOURCEFILE],
                                        [self.da_source, self.tf_source]):
            filepath = pathlib.Path(self.commandLineArgs.src_folder, filename)
            if self.commandLineArgs.is_test_mode:
                filepath = self._with_extra_suffix(filepath, MODIFIED_SUFFIX)
            print('writing Fortran source file to {}'.format(filepath))
            with open(filepath, 'w') as file:
                file.writelines('{}\n'.format(line) for line in sourcecode.get_source())


if __name__ == '__main__':
    # pipeline: command line -> read Fortran sources and namelists -> register new TFs -> write sources back
    tf_converter = TFConverter()
    for step in (tf_converter.parse_args,
                 tf_converter.read_source_files,
                 tf_converter.parse_tfs,
                 tf_converter.write):
        step()

    print('Done!')