Skip to content

Commit

Permalink
Respects user-provide output_file name for generated subgraphs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 568829158
  • Loading branch information
aferludin authored and tensorflower-gardener committed Sep 27, 2023
1 parent 3375033 commit d6298ab
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
11 changes: 8 additions & 3 deletions tensorflow_gnn/data/unigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,11 +774,16 @@ def __init__(self,
except ValueError:
self.file_format = "tfrecord"

@property
def key_value_format(self) -> bool:
return self.file_format == "sstable"

def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
kwargs = get_sharded_pattern_args(self.file_pattern)
# TODO(b/302278337): unify naming conventions for all formats.
if self.file_format == "tfrecord":
return (pcoll
| beam.io.tfrecordio.WriteToTFRecord(coder=self.coder, **kwargs))
return pcoll | beam.io.tfrecordio.WriteToTFRecord(
coder=self.coder, **get_sharded_pattern_args(self.file_pattern)
)
# Placeholder for Google-internal file writes
else:
raise NotImplementedError(
Expand Down
14 changes: 5 additions & 9 deletions tensorflow_gnn/experimental/sampler/beam/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,15 +320,11 @@ def app_main(argv) -> None:
program_pb, inputs, feeds=feeds, artifacts_path=artifacts_path
)
# results are tuple: example_id to tf.Example with graph tensors.
coder = beam.coders.ProtoCoder(tf.train.Example)
_ = (
examples
| 'DropExampleId' >> beam.Values()
| 'WriteToTFRecord'
>> beam.io.WriteToTFRecord(
os.path.join(output_dir, 'examples.tfrecord'), coder=coder
)
)
examples_writer = unigraph.WriteTable(FLAGS.output_samples)
if not examples_writer.key_value_format:
examples = examples | 'DropExampleId' >> beam.Values()
_ = examples | 'WriteExamples' >> examples_writer

logging.info('Pipeline complete')


Expand Down

0 comments on commit d6298ab

Please sign in to comment.