Skip to content

Commit 7a4e85c

Browse files
authored
Merge pull request #169 from bugoverdose/feat/task_agnostic_model
Setup task-agnostic model configuration
2 parents e656da8 + 485d3fc commit 7a4e85c

File tree

4 files changed

+90
-75
lines changed

4 files changed

+90
-75
lines changed

R/R/hBayesDM_model.R

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@
102102
#'
103103
#' @return A specific hBayesDM model function.
104104

105-
hBayesDM_model <- function(task_name,
105+
hBayesDM_model <- function(task_name = "",
106106
model_name,
107107
model_type = "",
108108
data_columns,
@@ -148,10 +148,20 @@ hBayesDM_model <- function(task_name,
148148
} else if (length(data) == 1 && is.character(data)) {
149149
# Set
150150
if (data == "example") {
151-
example_data <-
152-
ifelse(model_type == "",
153-
paste0(task_name, "_", "exampleData.txt"),
154-
paste0(task_name, "_", model_type, "_", "exampleData.txt"))
151+
model_meta <- c()
152+
if (task_name != "") {
153+
model_meta <- c(model_meta, task_name)
154+
} else {
155+
model_meta <- c(model_meta, model_name)
156+
}
157+
if (model_type != "") {
158+
model_meta <- c(model_meta, model_type)
159+
}
160+
if (length(model_meta) == 0) {
161+
stop("invalid model configuration")
162+
}
163+
example_data <- paste0(paste(model_meta, collapse = "_"), "_exampleData.txt")
164+
155165
datafile <- system.file("extdata", example_data, package = "hBayesDM")
156166

157167
if (!file.exists(datafile)) {
@@ -282,11 +292,20 @@ hBayesDM_model <- function(task_name,
282292
}
283293

284294
# Full name of model
285-
if (model_type == "") {
286-
model <- paste0(task_name, "_", model_name)
287-
} else {
288-
model <- paste0(task_name, "_", model_name, "_", model_type)
295+
model_meta <- c()
296+
if (task_name != "") {
297+
model_meta <- c(model_meta, task_name)
298+
}
299+
if (model_name != "") {
300+
model_meta <- c(model_meta, model_name)
301+
}
302+
if (model_type != "") {
303+
model_meta <- c(model_meta, model_type)
304+
}
305+
if (length(model_meta) == 0) {
306+
stop("invalid model configuration")
289307
}
308+
model <- paste(model_meta, collapse = "_")
290309

291310
# Set number of cores for parallel computing
292311
if (ncore <= 1) {

commons/convert-to-py.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
except ImportError:
1818
from yaml import Loader, Dumper
1919

20+
from utils import model_info, preprocess_func_prefix
2021

2122
def represent_none(self, _):
2223
return self.represent_scalar('tag:yaml.org,2002:null', '')
@@ -109,7 +110,6 @@ def message_additional_args(additional_args):
109110
else:
110111
return 'Not used for this model.'
111112

112-
113113
def main(info_fn):
114114
# Check if file exists
115115
if not info_fn.exists():
@@ -119,25 +119,13 @@ def main(info_fn):
119119
# Load model information
120120
with open(info_fn, 'r') as f:
121121
info = ordered_load(f, Loader=Loader)
122-
123-
# Model full name (Snake-case)
124-
model_function = [info['task_name']['code'], info['model_name']['code']]
125-
if info['model_type']['code']:
126-
model_function.append(info['model_type']['code'])
127-
model_function = '_'.join(model_function)
122+
123+
model_function, task_name_code, _, model_type_code = model_info(info)
128124

129125
# Model class name (Pascal-case)
130126
class_name = model_function.title().replace('_', '')
131127

132-
# Prefix to preprocess_func
133-
prefix_preprocess_func = info['task_name']['code']
134-
if info['model_type']['code']:
135-
prefix_preprocess_func += '_' + info['model_type']['code']
136-
137-
# Model type code
138-
model_type_code = info['model_type'].get('code')
139-
if model_type_code is None:
140-
model_type_code = ''
128+
prefix_preprocess_func = preprocess_func_prefix(info)
141129

142130
# Preprocess citations
143131
def shortify(cite: str) -> str:
@@ -157,12 +145,16 @@ def shortify(cite: str) -> str:
157145
(shortify(cite), cite) for cite in info['model_name']['cite'])
158146
else:
159147
model_cite = {}
148+
149+
task_name_desc = info.get('task_name', {}).get('desc')
150+
model_name_desc = info.get('model_name', {}).get('desc')
151+
model_type_desc = info.get('model_type', {}).get('desc')
160152

161153
# Read template for docstring
162154
with open(TEMPLATE_DOCS, 'r') as f:
163155
docstring_template = f.read().format(
164156
model_function=model_function,
165-
task_name=info['task_name']['desc'],
157+
task_name=task_name_desc if task_name_desc is not None else "",
166158
task_cite_short=format_list(
167159
task_cite,
168160
fmt='[{}]_',
@@ -171,7 +163,7 @@ def shortify(cite: str) -> str:
171163
task_cite,
172164
fmt='.. [{}] {}',
173165
sep='\n '),
174-
model_name=info['model_name']['desc'],
166+
model_name=model_name_desc if model_name_desc is not None else "",
175167
model_cite_short=format_list(
176168
model_cite,
177169
fmt='[{}]_',
@@ -181,7 +173,7 @@ def shortify(cite: str) -> str:
181173
if k not in task_cite),
182174
fmt='.. [{}] {}',
183175
sep='\n '),
184-
model_type=info['model_type']['desc'],
176+
model_type=model_type_desc if model_type_desc is not None else "",
185177
notes=format_list(
186178
info.get('notes') if info.get('notes') else [],
187179
fmt='.. note::\n {}',
@@ -221,7 +213,7 @@ def shortify(cite: str) -> str:
221213
model_function=model_function,
222214
class_name=class_name,
223215
prefix_preprocess_func=prefix_preprocess_func,
224-
task_name=info['task_name']['code'],
216+
task_name=task_name_code if task_name_code is not None else "",
225217
model_name=info['model_name']['code'],
226218
model_type=model_type_code,
227219
data_columns=format_list(
@@ -284,13 +276,7 @@ def generate_init(info_fns):
284276
with open(info_fn, 'r') as f:
285277
info = ordered_load(f, Loader=Loader)
286278

287-
# Model full name (Snake-case)
288-
model_function = [info['task_name']['code'],
289-
info['model_name']['code']]
290-
if info['model_type']['code']:
291-
model_function.append(info['model_type']['code'])
292-
model_function = '_'.join(model_function)
293-
279+
model_function, _, _, _ = model_info(info)
294280
mfs.append(model_function)
295281

296282
lines = []

commons/convert-to-r.py

Lines changed: 13 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
except ImportError:
1919
from yaml import Dumper, Loader
2020

21+
from utils import model_info, preprocess_func_prefix, extract_or_empty_string
2122

2223
def represent_none(self, _):
2324
return self.represent_scalar('tag:yaml.org,2002:null', '')
@@ -114,11 +115,7 @@ def format_references_block(cites_formatted):
114115

115116

116117
def generate_docs(info):
117-
# Model full name (Snake-case)
118-
model_function = [info['task_name']['code'], info['model_name']['code']]
119-
if info['model_type']['code']:
120-
model_function.append(info['model_type']['code'])
121-
model_function = '_'.join(model_function)
118+
model_function, _, _, _ = model_info(info)
122119

123120
# Citations
124121
if info['task_name'].get('cite'):
@@ -210,13 +207,13 @@ def generate_docs(info):
210207

211208
docs = docs_template % dict(
212209
model_function=model_function,
213-
task_name=info['task_name']['desc'],
214-
task_code=info['task_name']['code'],
210+
task_name=extract_or_empty_string(info, 'task_name', 'desc'),
211+
task_code=extract_or_empty_string(info, 'task_name', 'code'),
215212
task_parencite=task_parencite,
216-
model_name=info['model_name']['desc'],
217-
model_code=info['model_name']['code'],
213+
model_name=extract_or_empty_string(info, 'model_name', 'desc'),
214+
model_code=extract_or_empty_string(info, 'model_name', 'code'),
218215
model_parencite=model_parencite,
219-
model_type=info['model_type']['desc'],
216+
model_type=extract_or_empty_string(info, 'model_type', 'desc'),
220217
notes=notes,
221218
contributor=contributors,
222219
data_columns=data_columns,
@@ -234,22 +231,8 @@ def generate_docs(info):
234231

235232

236233
def generate_code(info):
237-
# Model full name (Snake-case)
238-
model_function = [info['task_name']['code'], info['model_name']['code']]
239-
if info['model_type']['code']:
240-
model_function.append(info['model_type']['code'])
241-
model_function = '_'.join(model_function)
242-
243-
# Prefix to preprocess_func
244-
prefix_preprocess_func = info['task_name']['code']
245-
if info['model_type']['code']:
246-
prefix_preprocess_func += '_' + info['model_type']['code']
247-
preprocess_func = prefix_preprocess_func + '_preprocess_func'
248-
249-
# Model type code
250-
model_type_code = info['model_type'].get('code')
251-
if model_type_code is None:
252-
model_type_code = ''
234+
model_function, _, _, model_type_code = model_info(info)
235+
preprocess_func = preprocess_func_prefix(info) + "_preprocess_func"
253236

254237
# Data columns
255238
data_columns = ', '.join([
@@ -295,8 +278,8 @@ def generate_code(info):
295278

296279
code = code_template % dict(
297280
model_function=model_function,
298-
task_code=info['task_name']['code'],
299-
model_code=info['model_name']['code'],
281+
task_code=extract_or_empty_string(info, 'task_name', 'code'),
282+
model_code=extract_or_empty_string(info, 'model_name', 'code'),
300283
model_type=model_type_code,
301284
data_columns=data_columns,
302285
parameters=parameters,
@@ -309,11 +292,7 @@ def generate_code(info):
309292

310293

311294
def generate_test(info):
312-
# Model full name (Snake-case)
313-
model_function = [info['task_name']['code'], info['model_name']['code']]
314-
if info['model_type']['code']:
315-
model_function.append(info['model_type']['code'])
316-
model_function = '_'.join(model_function)
295+
model_function, _, _, _ = model_info(info)
317296

318297
# Read template for model tests
319298
with open(TEMPLATE_TEST, 'r') as f:
@@ -340,12 +319,7 @@ def main(info_fn):
340319
test = generate_test(info)
341320
output = docs + code
342321

343-
# Model full name (Snake-case)
344-
model_function = [info['task_name']['code'],
345-
info['model_name']['code']]
346-
if info['model_type']['code']:
347-
model_function.append(info['model_type']['code'])
348-
model_function = '_'.join(model_function)
322+
model_function, _, _, _ = model_info(info)
349323

350324
# Make directories if not exist
351325
if not PATH_OUTPUT.exists():

commons/utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
def model_info(info):
2+
task_name_code = info.get('task_name', {}).get('code')
3+
model_name_code = info.get('model_name', {}).get('code')
4+
model_type_code = info.get('model_type', {}).get('code')
5+
6+
# Model full name (Snake-case)
7+
model_function = []
8+
if task_name_code is not None and len(task_name_code) > 0:
9+
model_function.append(task_name_code)
10+
if model_name_code is not None and len(model_name_code) > 0:
11+
model_function.append(model_name_code)
12+
if model_type_code is not None and len(model_type_code) > 0:
13+
model_function.append(model_type_code)
14+
model_function = '_'.join(model_function)
15+
16+
if model_type_code is None:
17+
model_type_code = ''
18+
return model_function, task_name_code, model_name_code, model_type_code
19+
20+
# Prefix to preprocess_func
21+
def preprocess_func_prefix(info):
22+
task_name_code = info.get('task_name', {}).get('code')
23+
model_name_code = info.get('model_name', {}).get('code')
24+
model_type_code = info.get('model_type', {}).get('code')
25+
26+
preprocess_func_prefix = []
27+
if task_name_code:
28+
preprocess_func_prefix.append(task_name_code)
29+
else:
30+
preprocess_func_prefix.append(model_name_code)
31+
if model_type_code:
32+
preprocess_func_prefix.append(model_type_code)
33+
return '_'.join(preprocess_func_prefix)
34+
35+
def extract_or_empty_string(info, key, subkey):
36+
return info[key][subkey] if info[key][subkey] is not None else ""

0 commit comments

Comments
 (0)