Concrete function generation (for TFLite) no longer works with newer TensorFlow and keras_nlp packages #11

@zhubarb

Description

Expected Behavior

The code below should trace without errors and produce a concrete function that can then be converted to TFLite:

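import tensorflow as tf
import keras_nlp
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")  # setup not shown in the report; preset assumed

@tf.function
def generate(prompt, max_length):
    return gpt2_lm.generate(prompt, max_length)

# Trace with a scalar string prompt and a fixed Python int max_length.
concrete_func = generate.get_concrete_function(tf.TensorSpec([], tf.string), 100)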

Actual Behavior

Error:

Traceback (most recent call last):
  File "/snap/pycharm-community/383/plugins/python-ce/helpers/pydev/_pydevd_bundle/pydevd_exec2.py", line 3, in Exec
    exec(exp, global_vars, local_vars)
  File "<input>", line 1, in <module>
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1251, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1221, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 696, in _initialize
    self._concrete_variable_creation_fn = tracing_compilation.trace_function(
                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function
    concrete_function = _maybe_define_function(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function
    concrete_function = _create_concrete_function(
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function
    traced_func_graph = func_graph_module.func_graph_from_py_func(
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/framework/func_graph.py", line 1059, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 599, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py", line 52, in autograph_handler
    raise e.ag_error_metadata.to_exception(e)
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py", line 41, in autograph_handler
    return api.converted_call(
           ^^^^^^^^^^^^^^^^^^^
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
    result = converted_f(*effective_args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_fileeqeqeph9.py", line 12, in tf__generate
    retval_ = ag__.converted_call(ag__.ld(gpt2_lm).generate, (ag__.ld(prompt), ag__.ld(max_length)), None, fscope)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py", line 441, in converted_call
    result = converted_f(*effective_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_filem5y9usa_.py", line 165, in tf__generate
    ag__.if_stmt(ag__.ld(self).preprocessor is not None, if_body_4, else_body_4, get_state_4, set_state_4, ('outputs',), 1)
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1217, in if_stmt
    _py_if_stmt(cond, body, orelse)
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1270, in _py_if_stmt
    return body() if cond else orelse()
           ^^^^^^
  File "/tmp/__autograph_generated_filem5y9usa_.py", line 160, in if_body_4
    outputs = [ag__.converted_call(ag__.ld(postprocess), (ag__.ld(x),), None, fscope) for x in ag__.ld(outputs)]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_filem5y9usa_.py", line 160, in <listcomp>
    outputs = [ag__.converted_call(ag__.ld(postprocess), (ag__.ld(x),), None, fscope) for x in ag__.ld(outputs)]
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py", line 339, in converted_call
    return _call_unconverted(f, args, kwargs, options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py", line 460, in _call_unconverted
    return f(*args)
           ^^^^^^^^
  File "/tmp/__autograph_generated_filem5y9usa_.py", line 111, in postprocess
    retval__3 = ag__.converted_call(ag__.ld(self).preprocessor.generate_postprocess, (ag__.ld(x),), None, fscope_3)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py", line 441, in converted_call
    result = converted_f(*effective_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file05s_92kj.py", line 30, in tf__generate_postprocess
    token_ids = ag__.converted_call(ag__.ld(ops).convert_to_numpy, (ag__.ld(token_ids),), None, fscope)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py", line 377, in converted_call
    return _call_unconverted(f, args, kwargs, options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py", line 460, in _call_unconverted
    return f(*args)
           ^^^^^^^^
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/keras/src/ops/core.py", line 512, in convert_to_numpy
    return backend.convert_to_numpy(x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/keras/src/backend/tensorflow/core.py", line 131, in convert_to_numpy
    return np.asarray(x)
           ^^^^^^^^^^^^^
  File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/tensorflow/python/framework/tensor.py", line 627, in __array__
    raise NotImplementedError(
NotImplementedError: in user code:
    File "/home/zhubarb/PycharmProjects/keras_nlp_tflite/main.py", line 24, in generate  *
        return gpt2_lm.generate(prompt, max_length)
    File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/keras_nlp/src/models/causal_lm.py", line 371, in postprocess  *
        return self.preprocessor.generate_postprocess(x)
    File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/keras_nlp/src/models/gpt2/gpt2_causal_lm_preprocessor.py", line 178, in generate_postprocess  *
        token_ids = ops.convert_to_numpy(token_ids)
    File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/keras/src/ops/core.py", line 512, in convert_to_numpy  **
        return backend.convert_to_numpy(x)
    File "/home/zhubarb/miniconda3/envs/keras_nlp_tflite/lib/python3.11/site-packages/keras/src/backend/tensorflow/core.py", line 131, in convert_to_numpy
        return np.asarray(x)
    NotImplementedError: Cannot convert a symbolic tf.Tensor (StatefulPartitionedCall:1) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.
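The innermost frames locate the failure: in keras_nlp 0.11.1, GPT2CausalLMPreprocessor.generate_postprocess calls ops.convert_to_numpy(token_ids), and the Keras TensorFlow backend implements that with np.asarray(x). While get_concrete_function is tracing, token_ids is a symbolic tf.Tensor with no value yet, so NumPy cannot convert it. Postprocessing that has to run inside a traced function must therefore stay in TensorFlow ops. The sketch below only illustrates that general pattern; it is not the KerasNLP implementation, and the pad id 0 and the name strip_padding are hypothetical.

```python
import tensorflow as tf

PAD_ID = 0  # hypothetical padding id, for illustration only

@tf.function(input_signature=[tf.TensorSpec([None], tf.int32)])
def strip_padding(token_ids):
    # Graph-safe: only TensorFlow ops, no np.asarray on the symbolic tensor.
    keep = tf.not_equal(token_ids, PAD_ID)
    return tf.boolean_mask(token_ids, keep)

# Tracing succeeds because nothing converts a symbolic tensor to NumPy.
concrete = strip_padding.get_concrete_function()
result = concrete(tf.constant([5, 7, 0, 0], dtype=tf.int32)).numpy()
print(result)  # [5 7]
```

Any conversion to NumPy (or to Python strings) would then happen after calling the concrete function, in eager mode, rather than during tracing.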

Steps to Reproduce the Problem

  1. In a Google Colab notebook, run !pip install keras_nlp (which installs the latest release) instead of !pip install -q git+https://github.com/keras-team/keras-nlp.git@google-io-2023 tensorflow-text==2.12.
  2. Run the code above; the error is raised at the concrete function generation step.
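Because the failure comes from np.asarray being applied to a tensor during tracing, it can likely be reproduced without downloading GPT-2. A minimal sketch, assuming only TensorFlow and NumPy are installed; the function names are hypothetical, and reproduces_error() is expected to return True on a TensorFlow version that behaves like the 2.16.1 run above:

```python
import numpy as np
import tensorflow as tf

@tf.function
def to_numpy_in_graph(token_ids):
    # Mirrors what keras.ops.convert_to_numpy does on the TensorFlow backend:
    # np.asarray on a tensor that only exists symbolically during tracing.
    return np.asarray(token_ids)

def reproduces_error():
    """Return True if tracing raises the NotImplementedError from the report."""
    try:
        to_numpy_in_graph.get_concrete_function(tf.TensorSpec([None], tf.int32))
    except NotImplementedError:
        return True
    return False

print(reproduces_error())
```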

Specifications

  • Version: Python 3.11, TensorFlow 2.16.1, Keras 3.3.3, keras_nlp 0.11.1, as printed by:
    print(tf.__version__)
    print(keras.__version__)
    print(keras_nlp.__version__)
  • Platform: Google Colab
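The versions above can also be collected from installed package metadata, without importing TensorFlow. This is a small sketch, not part of the original report; it assumes the PyPI distribution names tensorflow, keras, keras-nlp and tensorflow-text (the last is included because the working setup pinned tensorflow-text==2.12), and installed_versions is a hypothetical helper.

```python
from importlib import metadata

def installed_versions(packages=("tensorflow", "keras", "keras-nlp", "tensorflow-text")):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

for name, version in installed_versions().items():
    print(f"{name}: {version}")
```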
