megemini
Contributor

PR types

Others

PR changes

Others

Description

Use multiprocessing to isolate the xdoctester's environment. Files involved:

  • tools/sampcd_processor_xdoctest.py : add multiprocessing support. Each call to execute (which runs the xdoctest examples) starts a new process to perform the check, and xdoctest is patched before each check.
  • tools/test_sampcd_processor_xdoctest.py : add test cases.
  • python/paddle/static/nn/metric.py : modify its sample code to verify that this PR works correctly.
    Expected result: 6 pass, 1 skip, 1 fail.

@SigureMo please review ~

@paddle-bot

paddle-bot bot commented Aug 17, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

paddle-bot bot added the contributor (External developers) and status: proposed labels Aug 17, 2023
Comment on lines +1826 to +1847
    Examples:

        .. code-block:: python

            >>> import numpy as np
            >>> import paddle
            >>> paddle.enable_static()
            >>> data = paddle.static.data(name='X', shape=[None, 2, 28, 28], dtype='float32')
    """,
    'static_1': """
    this is docstring...

    Examples:

        .. code-block:: python

            >>> import numpy as np
            >>> import paddle
            >>> paddle.enable_static()
            >>> data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')

    """,
Member

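Shouldn't dynamic-graph and static-graph code be mixed together here, so we can test whether they interfere with each other?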

Contributor Author

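How would I mix dynamic and static here? Write an example and I'll try it ~

If you run these two snippets one after the other in the same Python shell, you get an error. Give it a try ~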

In [1]: >>> import numpy as np
   ...: >>> import paddle
   ...: >>> paddle.enable_static()
   ...: >>> data = paddle.static.data(name='X', shape=[None, 2, 28, 28], dtype='float32')

In [2]: >>> import numpy as np
   ...: >>> import paddle
   ...: >>> paddle.enable_static()
   ...: >>> data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[2], line 4
      2 import paddle
      3 paddle.enable_static()
----> 4 data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')

File ~/venv38/lib/python3.8/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File ~/venv38/lib/python3.8/site-packages/paddle/fluid/wrapped_decorator.py:26, in wrap_decorator.<locals>.__impl__(func, *args, **kwargs)
     23 @decorator.decorator
     24 def __impl__(func, *args, **kwargs):
     25     wrapped_func = decorator_func(func)
---> 26     return wrapped_func(*args, **kwargs)

File ~/venv38/lib/python3.8/site-packages/paddle/fluid/framework.py:558, in _static_only_.<locals>.__impl__(*args, **kwargs)
    553 def __impl__(*args, **kwargs):
    554     assert not _non_static_mode(), (
    555         "In PaddlePaddle 2.x, we turn on dynamic graph mode by default, and '%s()' is only supported in static graph mode. So if you want to use this api, please call 'paddle.enable_static()' before this api to enter static graph mode."
    556         % func.__name__
    557     )
--> 558     return func(*args, **kwargs)

File ~/venv38/lib/python3.8/site-packages/paddle/static/input.py:102, in data(name, shape, dtype, lod_level)
     99         shape[i] = -1
    101 if dtype:
--> 102     return helper.create_global_variable(
    103         name=name,
    104         shape=shape,
    105         dtype=dtype,
    106         type=core.VarDesc.VarType.LOD_TENSOR,
    107         stop_gradient=True,
    108         lod_level=lod_level,
    109         is_data=True,
    110         need_check_feed=True,
    111     )
    112 else:
    113     return helper.create_global_variable(
    114         name=name,
    115         shape=shape,
   (...)
    121         need_check_feed=True,
    122     )

File ~/venv38/lib/python3.8/site-packages/paddle/fluid/layer_helper_base.py:453, in LayerHelperBase.create_global_variable(self, persistable, *args, **kwargs)
    443 def create_global_variable(self, persistable=False, *args, **kwargs):
    444     """
    445     create global variable, note that there is no initializer for this global variable.
    446     Args:
   (...)
    451     Returns(Variable): the created variable.
    452     """
--> 453     return self.main_program.global_block().create_var(
    454         *args, persistable=persistable, **kwargs)

File ~/venv38/lib/python3.8/site-packages/paddle/fluid/framework.py:3836, in Block.create_var(self, *args, **kwargs)
   3834     var = _varbase_creator(*args, **kwargs)
   3835 else:
-> 3836     var = Variable(block=self, *args, **kwargs)
   3837     if 'initializer' in kwargs:
   3838         kwargs['initializer'](var, self)

File ~/venv38/lib/python3.8/site-packages/paddle/fluid/framework.py:1458, in Variable.__init__(self, block, type, name, shape, dtype, lod_level, capacity, persistable, error_clip, stop_gradient, is_data, need_check_feed, belong_to_optimizer, **kwargs)
   1456         shape = tuple(shape)
   1457         if shape != old_shape:
-> 1458             raise ValueError(
   1459                 "Variable '{0}' has been created before. The previous "
   1460                 "shape is {1}, the new shape is {2}. They are not "
   1461                 "matched.".format(self.name, old_shape, shape)
   1462             )
   1463 if dtype is not None:
   1464     if is_new_var:

ValueError: Variable 'X' has been created before. The previous shape is (-1, 2, 28, 28), the new shape is (-1, 1, 28, 28). They are not matched.

@SigureMo
Member

Expected result: 6 pass, 1 skip, 1 fail.

[screenshot]

As expected.

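luotao1 added the HappyOpenSource Pro (advanced edition of the Happy Open Source event, with more challenging tasks) label Aug 18, 2023
luotao1 removed the HappyOpenSource Pro (advanced edition of the Happy Open Source event, with more challenging tasks) label Aug 18, 2023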
print(output)
'''
.. code-block:: python
    :name: example-2
Member

After splitting this into two blocks, won't copy-from break?

Member

Or is all of this just for testing, to be reverted later?

Contributor Author

megemini, Aug 19, 2023

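This is just for testing ~ PR #56121 already changes it, so let's treat that one as final ~

That said, I do think splitting it is better. Why was this code originally put inside '''...''' as a comment?!

Also, the test passing was partly luck. Looking closely at this test code, ins_tag_weight is written incorrectly, so it gets pruned at runtime ~ and the weight and bias of fc_out are not fixed either, so the result has some chance of being wrong ... It just doesn't affect the multiprocessing test ...

This still needs to be fixed when the examples are revised ~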

Member

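That said, I do think splitting it is better. Why was this code originally put inside '''...''' as a comment?!

I also think splitting is better; once it is split, you can update the copy-from in the Chinese docs at the same time. I think this PR is fine; once the changes are restored, it's LGTM.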

@megemini
Contributor Author

Update 20230820

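  • Add the TIMEOUT directive

    Previous approach: time how long the code runs, and treat it as a timeout if it exceeds TEST_TIMEOUT.

    • Pro: records how long the code actually ran.
    • Con: if the code runs for a very long time, it can block CI.

    New approach: use the default TEST_TIMEOUT, or the time given in # doctest: +TIMEOUT(XX); if no result has been received within that time, treat it as a timeout.

    • Pro: does not block CI; execution is stopped once the time limit is exceeded.
    • Con: for a long-running example, the actual run time is not reported.
  • Add examples to metric.py. Expected result: 6 passed, 1 skipped, 1 that cannot run (syntax error), 1 timeout, 1 failed.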
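The new timeout approach can be sketched roughly as follows. This is a minimal illustration assuming a POSIX platform; `TEST_TIMEOUT = 10`, `get_timeout`, `_worker`, and `run_with_timeout` are hypothetical, and the regex only approximates the `# doctest: +TIMEOUT(XX)` directive by scanning the whole example, whereas the real tool hooks into xdoctest's own directive handling. The key point is that the parent waits on the result queue with a timeout instead of timing the run, and terminates the child if nothing arrives.

```python
import multiprocessing
import queue
import re

# Hypothetical default; the real TEST_TIMEOUT value is not stated in this thread.
TEST_TIMEOUT = 10

# Matches a directive such as "# doctest: +TIMEOUT(30)".
_TIMEOUT_RE = re.compile(r"#\s*doctest:\s*\+TIMEOUT\((\d+(?:\.\d+)?)\)")


def get_timeout(code, default=TEST_TIMEOUT):
    """Use the per-example TIMEOUT directive if present, else the default."""
    match = _TIMEOUT_RE.search(code)
    return float(match.group(1)) if match else float(default)


def _worker(code, result_queue):
    """Run the example in the child process and report passed/failed."""
    try:
        exec(code, {})
        result_queue.put("passed")
    except Exception:
        result_queue.put("failed")


def run_with_timeout(code, default=TEST_TIMEOUT):
    """Wait at most the allowed time for a result; stop the child otherwise."""
    ctx = multiprocessing.get_context("fork")  # assumes a POSIX platform
    result_queue = ctx.Queue()
    proc = ctx.Process(target=_worker, args=(code, result_queue))
    proc.start()
    try:
        status = result_queue.get(timeout=get_timeout(code, default))
    except queue.Empty:
        # No result in time: kill the example so it cannot block CI.
        # The actual run time is unknown, which is the trade-off noted above.
        proc.terminate()
        status = "timeout"
    proc.join()
    return status


if __name__ == "__main__":
    print(run_with_timeout("x = 1"))  # passed
    print(run_with_timeout("import time\ntime.sleep(5)  # doctest: +TIMEOUT(1)"))  # timeout
```

Waiting on the queue rather than measuring elapsed time is what keeps a hung example from stalling CI, at the cost of not knowing how long it would eventually have taken.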

@SigureMo
Member

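1 that cannot run (syntax error)

Syntax errors can be caught too? Nice.

But in CI I only seem to see a single parse error? I can't see the pass/fail results for the other sample code.

[screenshot]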

@megemini
Contributor Author

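Syntax errors can be caught too? Nice.

That has always been there, but it is counted as nocode ~ because a syntax error ultimately shows up the same way as nocode ~ it's only a stopgap ...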

1 apis could not run test or don't have sample codes
Test docstring from: file *metric.py* line number *424*.
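A tiny illustration of why a syntax error can end up in the same bucket as "nocode", under the assumption that each example is parsed before it is run. `classify_examples` is hypothetical and is not the processor's actual logic; it only shows that an example which fails to parse leaves nothing runnable, which looks the same as a docstring with no sample code.

```python
import ast


def classify_examples(examples):
    """Hypothetical bucketing of one docstring's examples.

    A docstring with no examples is "nocode". An example that cannot be parsed
    yields nothing runnable either, so it falls into the same bucket.
    """
    if not examples:
        return "nocode"
    for source in examples:
        try:
            ast.parse(source)
        except SyntaxError:
            return "nocode"
    return "runnable"


if __name__ == "__main__":
    print(classify_examples([]))           # nocode
    print(classify_examples(["x = 1"]))    # runnable
    print(classify_examples(["x = = 1"]))  # nocode
```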

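But in CI I only seem to see a single parse error? I can't see the pass/fail results for the other sample code.

Sigh ... I looked at the xdoctest source: in debug mode it imports ubelt, and when that import fails it raises an exception right away ... I didn't run in debug mode locally, so I didn't notice ... I'll add ubelt to the requirements too ~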

@SigureMo
Member

SigureMo commented Aug 22, 2023

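[screenshot]

This looks fine now; the related code can be restored ~

I'll add ubelt to the requirements too ~

I checked, and ubelt is a dependency of the xdoctest docs (xdoctest/docs/requirements.txt); in principle the code itself shouldn't depend on it.

I wanted to check whether it is part of some optional dependency set (installed with something like pip install xdoctest[all]), but it isn't among the optional dependencies: https://github.com/Erotemic/xdoctest/blob/main/setup.py#L208-L224

Then would it be better to add a comment after ubelt saying it is an xdoctest-related dependency? That way, if xdoctest is ever dropped, ubelt can be cleaned up along with it.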

@megemini
Contributor Author

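Then would it be better to add a comment after ubelt saying it is an xdoctest-related dependency? That way, if xdoctest is ever dropped, ubelt can be cleaned up along with it.

Yeah, I was thinking about that just last night ~ hahaha 🤣🤣🤣🤣🤣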

Member

SigureMo left a comment

LGTMeow 🐾

luotao1 merged commit ea4182d into PaddlePaddle:develop Aug 23, 2023
BeingGod pushed a commit to BeingGod/Paddle that referenced this pull request Sep 9, 2023
* [Change] make xdoctester multiprocessing

* [Add] add timeout directive

* [Fix] fix code-block

* [Fix] add ubelt requirements

* [Fix] patch xdoctest in __init__

* [Fix] codestyle

* [Change] restore metric.py
Labels
contributor External developers

4 participants